[RL] Training and Loading a stable-baselines3 Model

깜태 2021. 7. 13. 14:07

The code below uses an environment from the standard gym library.
For a custom Env, refer to the previous post:

https://tw0226.tistory.com/80

# On stable-baselines3 >= 2.0 (Gymnasium-based), use `import gymnasium as gym` instead
import gym

from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy


# Create environment
env = gym.make('LunarLander-v2')

# Instantiate the agent
model = DQN('MlpPolicy', env, verbose=1)
# Train the agent
model.learn(total_timesteps=int(2e5))
# Save the agent
model.save("dqn_lunar")
del model  # delete trained model to demonstrate loading

# Load the trained agent
model = DQN.load("dqn_lunar", env=env)

# Evaluate the agent
# NOTE: If you use wrappers with your environment that modify rewards,
#       this will be reflected here. To evaluate with original rewards,
#       wrap environment in a "Monitor" wrapper before other wrappers.
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")

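# Enjoy trained agent
# Step the vectorized env wrapped by the model: it resets automatically when
# an episode ends, so the loop never keeps stepping a finished episode
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render()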
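
As a rough illustration of what the evaluation step reports, here is a minimal, self-contained sketch in plain Python (no gym or stable-baselines3 needed). The `ToyChainEnv` environment, the fixed `q_values`, and the `select_action`/`evaluate` helpers are hypothetical, made up for this example, and this is not the library's implementation. It assumes that `deterministic=True` in DQN's `predict` means "always take the highest-Q action" while `deterministic=False` keeps epsilon-greedy exploration, and that `evaluate_policy` summarizes undiscounted episode returns as a mean and a population standard deviation.

```python
import random
import statistics


class ToyChainEnv:
    """Hypothetical environment for illustration only (not part of gym).

    Action 1 earns +1 reward, action 0 earns 0. An episode ends after
    `horizon` steps. step() returns an old gym-style 4-tuple.
    """

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t  # observation: current time step

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        done = self.t >= self.horizon
        return self.t, reward, done, {}


def select_action(q_values, epsilon, deterministic, rng):
    """Epsilon-greedy choice; with deterministic=True, epsilon is ignored (argmax only)."""
    if not deterministic and rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


def evaluate(env, q_values, n_eval_episodes, deterministic=True, epsilon=0.05, seed=0):
    """Run episodes and return (mean, std) of the undiscounted episode returns."""
    rng = random.Random(seed)
    returns = []
    for _ in range(n_eval_episodes):
        env.reset()
        done, total = False, 0.0
        while not done:
            action = select_action(q_values, epsilon, deterministic, rng)
            _, reward, done, _ = env.step(action)
            total += reward
        returns.append(total)
    # Population standard deviation, i.e. the same as numpy's np.std default (ddof=0)
    return statistics.mean(returns), statistics.pstdev(returns)


# Hypothetical fixed Q-values: action 1 looks better in every state
mean_reward, std_reward = evaluate(ToyChainEnv(horizon=5), q_values=[0.1, 0.9], n_eval_episodes=10)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")  # → mean_reward=5.00 +/- 0.00
```

With `deterministic=True` the exploration rate is ignored, so every episode takes action 1 and the returns have zero spread. Passing `deterministic=False` with a large `epsilon` lowers the mean and adds variance, which is one reason the evaluation and rollout code uses `deterministic=True`.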

Source: https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#accessing-and-modifying-model-parameters