ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [논문리뷰-DQN] Deep Q-Network
    Papers Review 2024. 5. 9. 22:20

     

     

    목차

      - 논문 : Playing Atari with Deep Reinforcement Learning (paper)
      - 핵심 내용 : 딥러닝 기반의 강화학습을 진행한다. 데이터를 랜덤하게 샘플링하는 방식과, Q-network 추가로 DQN 모델을 완성한다.

       

      사실 정확하게 말하면 논문 리뷰라기 보다는, DQN이 어떤 알고리즘인지 백그라운드와 핵심 포인트를 리뷰한 글.

       


       

      1. Reinforcement Learning

      • 일단 가장 먼저 강화학습이란 무엇인가?
        • 환경(Environment)과 상호작용하는 모델(Agent)를 학습시키는 것
        • 주어진 상태(State)에서 행동(Action)을 취하여 그에 상응하는 보상(Reward)를 받게 된다.
        • 최종적인 보상값을 최대화 하도록 행동을 강화한다.
        • 게임, 로보틱스, 자율주행 분야에서 활발히 사용된다
      • 예시: 마리오 게임을 한다고 가정한다면 ? 
        • Agent : 마리오
        • State : 현재 화면 픽셀 (마리오가 움직일 때마다 state는 바뀌는 것)
        • Actions : 앞으로 가기, 뒤로 가기, 점프
        • Rewards : 포인트 획득, 몬스터 만나서 사망(?)

      • DRL (Deep Reinforcement Learning) : 딥러닝 기반 강화학습 알고리즘
        • DeepMind에서 2013년에 DQN 이라는 DRL 모델을 발표.
        • Google DeepMind에서 2016년 알파고를 공개하며 DQN이 더욱 유명해짐.

       

      1-1. Q-learning

      딥러닝 이전의 강화학습 모델로 Q-learning 이란 것이 있다. 이는 가능한 State-Action 조합을 표로 그리고 보상인 Q-value를 구하여 가장 높은 값을 갖는 행동을 취하는 방식이다. 이때의 보상을 어떻게 구하는지 이해하려면 Markov Decision Process와 Bellman Equation을 알아야 한다.

      1-2. Markov Decision Process

      • 모든 상태는 오직 그 직전의 상태와 그때의 행동에 대해서만 의존한다는 가정
      • 모든 과거 상태를 고려할 수 없기 때문.
      • 0~t-1 까지의 과거시점 상태일 때, 현재시점 t일 확률  = t-1의 과거시점 상태일 때, 현재시점 t일 확률

      1-3. The Bellman Equation

      • 특정 상태에서 어떠한 행동을 취했을 때 받게 될 보상을 알고 있다면, 가장 높은 보상이 예상되는 행동을 선택하면 된다.
      • 현재 상태에서 행동을 취했을 때 받게 되는 즉각 보상 + 미래 상태에서 얻게될 보상 중 최댓값
      • 매 상태에서 현재+미래의 보상을 최대화할 수 있는 행동을 선택하는 방향으로 학습 된다.
      • Q(s, a) = 현재 상태 s에서 a라는 행동을 취했을 때 얻게 될 최대의 보상 값. Q-value
        • r(s, a) = 현재 상태 s에서 a라는 행동을 취했을 때 얻게 될 즉각 보상
        • gamma * maxQ(s', a) = 미래 상태 s'에서 취할 수 있는 행동(얻을 수 있는 보상) 중 최댓값
          gamma : 미래 가치 중요도를 조절하는 변수로 사용된다

       

      즉, 모든 상태는 그 직전의 상태만 알면 되므로, 딱 t 시점과 t+1의 시점만을 계산할 것이다. (Markov decision)

      계산(Q-value)은 t 시점에서 action 을 취했을 때의 즉각보상 + (미래가치 조절 변수)*미래에 받을 수 있는 경우의 수 중 최댓값 (Bellman Equation)으로 계산하면 된다.

      참고로 입실론-greedy 방식을 적용하여 explore하게 학습도 진행하지만, 오늘은 깊게 살펴보진 않겠다.

       

       

      • 예시

      Game Board의 빨간 자동차가 트로피 목표지점에 도달하기 위해 어떤 행동을 취해야 할지 알아보자.

      1. 먼저 Q-table을 그린다. 가능한 state(6개) * 취할 수 있는 action(4개)

      2. 현재 시점(s_t)에서 오른쪽으로 가게 될 때 얻을 Q-value 구하기

      - 즉각 보상(r) : 오른쪽으로 갔을 때 보상은 0

      - 미래 보상 중 최댓값 max Q(s') : 미래 시점(s_t+1) 에는 위로 올라가면 목표지점에 도달하므로 최대 보상값은 1이다.

      - gamma : 0.95

      - Q(s, a) = r(s, a) + gamma*maxQ(s', a) = 0 + 0.95*1 = 0.95

      3. Q-Table[3,2] 값을 0.95로 업데이트 해준다.

       

      2. DQN (Deep Q-Network)

      • 앞서 본 Q-table은 조합 가능한 state의 갯수가 많아질수록 무한정 테이블이 커지게 된다. 만약 1080*1080 pixel의 *3(rgb)라면 벌써 350만개의 열이 생성된다. 메모리가 무한의 크기만큼 커지게 된다.
      • 테이블 대신 Deep neural network인 CNN을 적용하여 학습하는 강화학습 방식을 DQN이라고 한다.
      • CNN을 이용하면 커다란 state vector(위의 예시처럼 350만개)가 들어와도, filter를 통해 사이즈를 줄일 수 있다.

       

      • 4개의 프레임 Input을 하나의 state로 보고, 해당 state에 대한 action의 갯수만큼 class를 만들어 예측한다. 마치 multi class classification!  즉, 350만개의 값이 들어온다 하더라도 최종적으로 출력하는 값은 N(action)개이다. 그리고 가장 높은 값을 갖는 Q(s,a) 를 선택하면 된다.
      • Input인 state vector는 엄청 크게 들어와도 되지만, action은 크지 않는 것이 강화학습에 좋겠다!

       

      • 장점
        • 기존 Q-learning은 조합 가능한 state 수만큼 테이블의 크기가 켜저서 학습이 어려웠으나, DQN은 CNN을 이용하여 feature vector를 축소시킬 수 있기 때문에 state vector가 커져도 학습이 가능하다.
        • 기존의 딥러닝은 hand-label 이 필요한 지도학습이었으나, 강화학습은 raw data 그대로 학습에 사용되므로 featur engineering에 대한 공수가 적으며, 도메인 지식이 적게 요구되는 End-to-end model 이라고 할 수 있다. 
      • 단점 (Naive DQN) 
        • 실제 Naive DQN(Q-table만 nn에 태운 모델) 과 단순한 linear 모델을 비교학습 하면 전자의 성능이 많이 떨어진다고 한다. 그 이유는
        • 첫째, 데이터 샘플들 사이에 Temporal correlations가 있기 때문이다. (-> Experience Replay = Replay buffer 로 해결)
        • 둘째, target이 고정되어 있지 않은 Non-stationary targets 문제를 갖는다. (-> Target network로 해결)

       

      2-1. Experience replay (Replay buffer)

      Online RL agents have temporal correlations between samples

      1. strongly temproally-corrrelated updates

      ex) 로봇이 걷도록 강화학습 시킨다고 할 때, 로봇이 넘어지기 직전의 스텝까지는 negative reward를 받지 않는다. 다음 스텝에 무조건 넘어진다 하더라도. 이처럼 데이터 샘플들 사이에는 시간적인 관계가 존재한다.

      2. Rapid forgetting rare experience

      ex) 로봇이 넘어졌다는 데이터는 굉장히 중요한 부분인데, 한 번만 학습되고 지나가게 되면 금방 해당 학습을 잊게 된다.

       

      즉, 데이터 샘플은 순서적으로 유기적인 관계를 맺고 있으며, 이는 위 예시처럼 단점으로 작용한다.

      이를 해결하기 위해 데이터 샘플을 임시로 저장해두었다가 샘플링하여 학습하는 방법이 experience replay이다.

       

      [ 방식 ]

      1. Q-learning 학습에 필요한 4-tuple: (s, a, r, s') 을 Replay buffer에 저장한다.

      2. 미니배치만큼 랜덤 샘플링하여 학습을 진행한다.

      3. buffer에는 maximum size가 있어 deque로 데이터가 쌓인다. (데이터 선입선출)

       

      [ 장점 ]

      1. 학습 도중 샘플들 사이의 temporal correlation을 줄일 수 있다.

      2. 미니 배치만큼 한 번에 학습되므로 빠르게 학습된다.

      3. rare한 샘플을 재사용하기 때문에 효용성이 높아진다.

       

       

       

      2-2. Target Network

      먼저 DQN의 cost function을 알아보자

      앞부분이 위에서 말한 Q-value predict 값이고, 뒷 부분(Q*)이 target 값이다. 두 사이의 loss를 감소시켜야 한다.

      이때 target 을 무엇으로 해야하는가? Q-table을 만들 수는 없다. 그래서 target value로 기존의 Q-value를 찾는 함수(CNN)를 그대로 태울 것이다. 이때 Q 함수는 replay buffer에서 데이터를 샘플링하여 학습할 때마다 weight가 업데이트 된다. 이렇게 되면 target Q값도 변한다. 이것을 non-stationary target problem이라고 한다.

       

      그러므로 변화하지 않고 고정된 weight를 사용하는 Target Q-network와, 파라미터를 학습하기 위한 Behavior Q-network를 따로 설계한다. 동일한 네트워크 구조이나 서로 다른 파라미터를 갖는 것이다.

      Behavior Q-network의 weight는 미니 배치마다 업데이트 되며,

      Target Q-network의 weight는 일정 학습이 지나면 Behavior Q-network의 weight로 교체된다.

       

      위의 수식에서 앞부분이 Behavior Q-network가 되고, 뒷 부분(Q*)이 Target Q-network가 되는 것이다.

       

      weigth 업데이트는 아래 수식처럼 loss 계산이 되고,

      아래 수식처럼 역전파(SGD) 된다.

       

       

      3. Architecture

      앞서 설명한 모든 것 :  1. CNN으로 Q-value 생성 2. replay buffer 3. target Q-network 을 도식화 하고, 수도코드를 작성한 것이다.

       

      1. behavior Q-network(CNN)를 랜덤한 weight로 생성한다.

      2. 동일한 구조와 weight를 갖는 target Q-network를 생성한다.

      3. R 사이즈 만큼의 deque 구조인 replay buffer를 생성한다.

      4. s_t: input state (사진이라고 할 경우, 84*84 픽셀을 flatten한 vector) 를 CNN에 태웠을 때, 가장 Q-value가 높은 a:행동 를 선택한다. (입실론 그리디 확률로 랜덤한 action을 선택). 

      5. s_t에서 a일 때, s_t+1에서 얻을 수 있는 최대치의 r:보상 을 계산한다.

      6. (s_t, a, r, s_t+1) 을 replay buffer에 저장한다. 크기가 R을 넘으면 가장 오래된 데이터를 삭제한다.

      7. 미니배치 갯수(B)보다 replay buffer가 더 많아지면 B만큼 랜덤하게 샘플링을 진행한다.

      8. Behavior Q-network에 태워서 y_pred값을 계산하고, Target Q-network에 태워서 y_target 값을 계산한다.

      9. 계산된 loss는 Behavior Q-network weight에 업데이트 한다.

      10. C step마다 Behavior Q-network weight을 Target Q-network weight에 복사한다.

      댓글

    Designed by Tistory.