AI | ML/Reinforcement Learning

[RL] 강화학습 모델 라이브러리 stable-baselines3 사용해보기

깜태 2021. 7. 13. 13:50
728x90

링크 : https://github.com/DLR-RM/stable-baselines3

 

DLR-RM/stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms. - DLR-RM/stable-baselines3

github.com

 

소개

강화학습을 학습시키려고 하면 모델들마다 사용법이 각각 달라서 많은 에러를 겪곤 한다.

하지만, stable-baselines3 라이브러리를 이용하면 아래의 장점이 생긴다.

1. 다양하게 구현된 모델들을 사용할 수 있고,

2. 경우에 따라 multi-processing 까지 지원된다.

 

일일이 모델마다 구현할 필요가 없어진다는 의미가 된다.

 

이제 사용법을 알아보자

 

1. Env 상속

stable-baselines3에서 사용하는 env는 Gym 라이브러리 기반의 Env를 사용한다.

 

따라서, 사용하려는 환경을 클래스화하고, gym.Env 를 상속받아야 한다.

 

나의 경우는 테트리스를 강화학습으로 시도해보고 있는데, 테트리스를 예로 들면 Env는 다음과 같이 정의된다.

 

class TetrisApp(gym.Env):
	def __init__(self):
    	self.height = 20
        self.width = 10
        self.observation_space = spaces.Box(low=0, high=1e+8, shape=(1, 4), dtype=np.float)
        self.action_space = spaces.Discrete(5)
        #  생략

2. State 정의하기 

강화학습을 진행하면서 Env는 State를 받아오기 위한 용도로 사용된다.

gym.Env에서는 self.observation_space와 self.action_space를 정의해야한다.

 

observation_space는 말그대로 관측되는 값을 담는 변수이며,

action_space는 행동을 어떻게 할지 정의하는 변수이다.

 

테트리스의 경우는 (Up, Down, Left, Right, Space) 로 각각 독립되어 있기때문에
gym.spaces 메소드 내 Discrete를 사용하였고,

observation_space는 [지워진 줄, 구멍 개수, column 별 높이 차의 합, column 별 쌓인 블록의 합]으로 정의하여서,

이런 값을 담을 수 있는 Box형태의 변수를 사용하였다.

 

(주의사항으로 제대로 정의되지 않은 경우, 에러가 발생한다.)

3. step, reset, render 정의하기

보통의 강화학습 환경에서는

에피소드가 갱신될 때마다 사용되는 env.reset() 메소드,

에피소드가 진행되는 중에 다음 상태를 불러오는 env.step() 메소드,

학습이 다 진행된 후나, 진행되고 있는 과정을 보기 위해 render() 메소드가 필요하다.

마찬가지로 gym.Env도 다음의 형태를 지니고 있어, step, reset, render 를 정의해야한다.

 

reset 메소드와 step 메소드의 차이는
reset의 경우 state만 반환하면 되는 반면에,

step의 경우는 state, reward, done, info 가 반환된다.

 

info의 경우는 dict형태로 비워놔도 학습이 진행된다.

 

나의 경우는 아래와 같이 정의하였다.

 

    def reset(self):
        self.init_game()
        states = self.feature_extraction()
        return states
        
    def init_game(self):
        self.board = new_board()
        self.score = 0
        self.lines = 0
        self.level = 1
        self.new_stone()
        pygame.time.set_timer(pygame.USEREVENT + 1, 1000)
        return self.board
    
    def step(self, action):
    	key_actions = {
            0: self.rotate_stone,
            1: lambda: self.move(-1),
            2: lambda: self.move(+1),
            3: self.insta_drop,
            4: lambda: self.drop(True),
            "ESCAPE": self.quit,
            "p": self.toggle_pause,
            "RETURN": self.start_game,
        }
        '''
        ~~생략
        '''
        infos = {}
        states = self.feature_extraction()
        return states, self.score, self.gameover, infos

 

4. 실행

 

그리고 실행하는 방법은 아주 쉽다.

위에서 만든 custom env 환경을 불러와서 실행하면 된다.

여기서 verbose는 0은 no output, 1이면 info, 2는 debug 모드를 의미한다.

추가로 GPU가 있는 경우 자동으로 GPU를 잡아주고, device='cpu' 옵션으로 cpu만 사용도 가능하다.

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from tetris_gym import TetrisApp
tetris_env = make_vec_env(TetrisApp, n_envs=8)
model = PPO('MlpPolicy', tetris_env, verbose=1)
model.learn(total_timesteps=(1e+6))
model.save("tetris")

 

5. 결과 확인

학습 결과가 다 나오면 마지막으로 zip 파일 형태로 파일이 저장된다.

 

참고: https://www.kaggle.com/ashleypaloalto2020/hungry-geese-self-play-agent-using-stable-baseli

https://github.com/DLR-RM/stable-baselines3

728x90