본문 바로가기
  • 기술을 이야기하지만 사람을 생각합니다.
20. 인공지능과 딥러닝

DQN 실습 :: CartPole 게임

by WE DONE IT. 2020. 8. 30.

Deep Q-Network 튜토리얼

클래식한 게임 CartPole에 Deep Q-Learning(Reinforcement learning)을 적용한 코드를 실습해 보았습니다.

<T 아카데미 - 강화학습 입문하기>와 다른 아티클 등을 참고하여 정리하였습니다. OpenAI Gym에서 Chartpole 게임에 DQN을 적용한 튜토리얼과 설명을 확인할 수 있습니다.

CartPole 게임 

 

Chartpole 게임 방식

Chartpole은 카트를 왼쪽 또는 오른쪽으로 잘 밀어서 균형을 잡는 문제이다. 이 게임에 DQN을 적용하여, remember와 replay를 반복하며 스스로 방법을 터득하게 된다.

  • reward function : 매 타임스텝마다 +1씩 보상을 받음
  • 막대가 중심에서 2.4유닛 이상 기울어지거나, 멀리 떨어지면 종료됨
    • State space : 현재 위치, 막도의 기울기, 속도, 막대의 각 속도 (숫자 4개의 4차원)
    • Action space : 오른쪽 또는 왼쪽으로 밂

신경망이 주요 목표는 target과 prediction의 격차를 줄이는 것이다.

 

Deep Q Network 시스템


GitHub 주소 https://github.com/seungeunrho/minimalRL/blob/master/dqn.py

01. Import & Hyperparameter setting

  • learning_rate : 목표와 예측 사이에 loss를 줄이기 위해 뉴럴 네트워크를 얼마나 반복할지 정하는 
  • gamma : 미래 discounted reward를 계산하기 위한 파라미터
  • exploration_rate : 에이전트가 초기에는 랜덤으로 action하며 경험을 습득하며, 언제쯤 탐험을 맡길지에 대한 값
  • exploration_decay : 스스로 학습하면서 나아지기 때문에, 탐험 횟수를 점차 감소시킴
  • episodes : 에이전트가 스스로 학습하기 위해 얼마나 게임을 플레이할 것인가에 대한 값

import gym
import collections
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
#Hyberparameters
learning_rate = 0.0005
gamma = 0.98
buffer_limit = 50000
batch_size = 32

  • Collections 라이브러리 : replay buffer에서 쓰일 deque를 import 하기 위함
  •  batch_size : transition 32개를 모아서 loss를 줄이도록 parameter를 업데이트함

 

02. Replay Buffer


class ReplayBuffer():
    def __init__(self):
        self.buffer = collections.deque(maxlen=buffer_limit)
    
    def put(self, transition):
        self.buffer.append(transition)
    
    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []
        
        for transition in mini_batch:
            s, a, r, s_prime, done_mask = transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask_lst.append([done_mask])

        return torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
               torch.tensor(r_lst), torch.tensor(s_prime_lst, dtype=torch.float), \
               torch.tensor(done_mask_lst)
    
    def size(self):
        return len(self.buffer)

  • ReplayBuffer( )
    • put( ) : 집어 넣는 method
      • buffer_limit : buffer 최대 크기
      • transition이 들어오면 buffer에 넣어주는 역할 
    • sample( ) : 빼내는 method
      • 32개를 뽑아 tensor로 만드는 과정 
      • buffer에서 32개를 랜덤으로 뽑아서 mini_batch_를 만듦
    • size( ) :  buffer에 몇 개 들어갔는지?

 

03. Q Network 


class Qnet(nn.Module):
    def __init__(self):
        super(Qnet, self).__init__()
        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 2)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
      
    def sample_action(self, obs, epsilon):
        out = self.forward(obs)
        coin = random.random()
        if coin < epsilon:
            return random.randint(0,1)
        else : 
            return out.argmax().item()

  • Qnet( ): 뉴럴 네트워크
    • forward( )
      • Input : 4차원 -> [Fully Connected, Relu] -> H2 : 256차원 -> [Fully Connected] -> Q: 2차원 (왼쪽 or 오른쪽)
    • sample_action( ) : epsilon 그리는 과정 (8% -> 1% 줄어듦)
      • coin 0.05 보다 작을 확률 : 5% -> 랜덤, 나머지는 q-value가 가장 높은 것 뽑자

Q Network 구조

 

04. Train


def train(q, q_target, memory, optimizer):
    for i in range(10):
        s,a,r,s_prime,done_mask = memory.sample(batch_size)

        q_out = q(s)
        q_a = q_out.gather(1,a)
        max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1)
        target = r + gamma * max_q_prime * done_mask
        loss = F.smooth_l1_loss(q_a, target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

  • 32개의 s_prime
  • q_out - Shape : [32,2]
  • q_a : 실제 한 액션들의 value. 취한 action의 q값만 골라냄 - Shape : [32, 1]
  • 딥러닝은 정답과 현재 값을 줄이도록 하는 게 목표!

 

05. Main


def main():
    env = gym.make('CartPole-v1')
    q = Qnet()
    q_target = Qnet()
    q_target.load_state_dict(q.state_dict())
    memory = ReplayBuffer()

    print_interval = 20
    score = 0.0  
    optimizer = optim.Adam(q.parameters(), lr=learning_rate)

    for n_epi in range(10000):
        epsilon = max(0.01, 0.08 - 0.01*(n_epi/200)) #Linear annealing from 8% to 1%
        s = env.reset()
        done = False

        while not done:
            a = q.sample_action(torch.from_numpy(s).float(), epsilon)      
            s_prime, r, done, info = env.step(a)
            done_mask = 0.0 if done else 1.0
            memory.put((s,a,r/100.0,s_prime, done_mask))
            s = s_prime

            score += r
            if done:
                break
            
        if memory.size()>2000:
            train(q, q_target, memory, optimizer)

        if n_epi%print_interval==0 and n_epi!=0:
            q_target.load_state_dict(q.state_dict())
            print("n_episode :{}, score : {:.1f}, n_buffer : {}, eps : {:.1f}%".format(
                                                            n_epi, score/print_interval, memory.size(), epsilon*100))
            score = 0.0
    env.close()

if __name__ == '__main__':
    main()

  • epsilon :  8% -> 1% 줄어듦
  • episod == 10,000
  • t == 600 
  • a : action을 뽑아 환경에 던져줌 -> 다음 state, reward, 에피소드 끝났는지? => 하나의 transition
    • r을 100으로 나누는 이유 : 스케일을 줄이기 위한 매직 넘버 ~('-')~
  • memory.size( ) > 2000
    : 에피소드를 메모리에 넣은 후, 메모리에 너무 샘플이 적으면 안 됨 2,000개 넘으면 학습 시작!

참고 자료

 

강화학습 입문하기 | T아카데미 온라인강의

1. 강화학습이란 무엇인지와 MP(마르코프 프로세스), 가치함수에 대해 알아본다. 2. DQN을 이용한 에이전트 학습 방법에 대해 알아보다.

tacademy.skplanet.com

 

 

My Journey Into Deep Q-Learning with Keras and Gym

This post will show you how to implement Deep Reinforcement Learning (Deep Q-Learning) applied to play an old Game: CartPole.

medium.com

 

 

Reinforcement Learning (DQN) Tutorial — PyTorch Tutorials 1.6.0 documentation

Note Click here to download the full example code Reinforcement Learning (DQN) Tutorial Author: Adam Paszke This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent on the CartPole-v0 task from the OpenAI Gym. Task The agent has to dec

pytorch.org

 

댓글