
Hands-on Session with Stable-Baselines3 (SB3)

Antonin RAFFIN (@araffin2)
German Aerospace Center (DLR)
https://araffin.github.io/

Outline

  1. Stable-Baselines3 Overview
  2. Questions?
  3. Hands-on Session

Stable-Baselines3 Overview

History of the project

To fork or not to fork? (2018)

[Screenshots: failing build badge, hard-to-read code, reddit discussions]

Stable-Baselines?


    from stable_baselines import A2C

    model = A2C("MlpPolicy", "CartPole-v1")
    model.learn(50000)
    model.save("a2c_cartpole")

[Screenshot: pull request to OpenAI Baselines]

Stable-Baselines (2018-2020)

Ashley Hill, Maximilian Ernestus, Adam Gleave, Anssi Kanervisto

https://github.com/hill-a/stable-baselines

  • 5 maintainers
  • 60+ contributors
  • 1000+ issues / pull requests
  • 300+ citations

Stable-Baselines3 (2020-...)

https://github.com/DLR-RM/stable-baselines3

  • cleaner codebase but same API
  • performance checked against reference implementations
  • code coverage: 95%
  • active community

Active Community

Stable-Baselines (SB2)

[Badges: GitHub stars, PyPI downloads]

Stable-Baselines3 (SB3)

[Badges: GitHub activity, PyPI downloads, contributors]

Design Principles

  • reliable implementations of RL algorithms
  • user-friendly
  • focus on model-free, single-agent RL
  • favour readability and simplicity over modularity

Features

  • algorithms: A2C, DDPG, DQN, HER, PPO, SAC and TD3
  • clean and simple interface
  • fully documented
  • comprehensive (tensorboard logging, callbacks, ...)
  • training framework included (RL Zoo)
  • SB3 Contrib: QR-DQN, TQC, ...
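SB3 Contrib algorithms keep the same interface as the core library; a minimal sketch (assuming the sb3-contrib package is installed):

    from sb3_contrib import QRDQN

    # Contrib algorithms are drop-in: same constructor, learn() and save()
    model = QRDQN("MlpPolicy", "CartPole-v1", verbose=1)
    model.learn(total_timesteps=50000)
    model.save("qrdqn_cartpole")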

Getting Started


    import gym
    from stable_baselines3 import SAC

    # Train an agent using Soft Actor-Critic on Pendulum-v0
    env = gym.make("Pendulum-v0")
    model = SAC("MlpPolicy", env, verbose=1)
    # Train the model
    model.learn(total_timesteps=20000)
    # Save the model
    model.save("sac_pendulum")
    # Load the trained model
    model = SAC.load("sac_pendulum")
    # Start a new episode
    obs = env.reset()
    # What action to take in state `obs`?
    action, _ = model.predict(obs, deterministic=True)
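From there, rolling out a full episode follows the standard Gym API (a minimal sketch using the pre-0.26 Gym step interface current at the time):

    # Run the trained agent for one episode
    obs = env.reset()
    done = False
    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)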
					

SB3 Training loop

[Diagram: SB3 training loop]
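In pseudocode, model.learn() reduces to a simple loop (a simplified sketch, not the actual SB3 source; collect_rollouts() and train() mirror the names of SB3's internal methods):

    # while num_timesteps < total_timesteps:
    #     collect_rollouts(env)   # act in the env, store transitions
    #     train()                 # gradient updates on the collected data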

Training framework: RL Zoo

https://github.com/DLR-RM/rl-baselines3-zoo

  • training, loading, plotting, hyperparameter optimization
  • 100+ trained models + tuned hyperparameters

    # Train an A2C agent on Atari Breakout using tuned hyperparameters,
    # evaluate the agent every 10k steps and save a checkpoint every 50k steps
    python train.py --algo a2c --env BreakoutNoFrameskip-v4 \
        --eval-freq 10000 --save-freq 50000

    # Plot the learning curve
    python scripts/all_plots.py -a a2c -e BreakoutNoFrameskip-v4 -f logs/
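The zoo's other entry points follow the same pattern (a sketch based on the zoo's README; the exact flags are an assumption and may differ between versions):

    # Watch a trained agent from the zoo
    python enjoy.py --algo a2c --env BreakoutNoFrameskip-v4 -f logs/
    # Tune hyperparameters with Optuna
    python train.py --algo ppo --env CartPole-v1 -optimize --n-trials 100 --n-jobs 2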
							
[Photos: example trained agents — robotic neck, target reaching, racing car, bert robot]

Recap

  • reliable RL implementations
  • user-friendly
  • training framework (RL Zoo)

Questions?

Upcoming: Hands-on session

https://github.com/araffin/rl-handson-rlvs21

  • getting started
  • Gym wrappers
  • callbacks
  • multiprocessing (see the sketch after this list)
  • importance of hyperparameters
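A taste of the multiprocessing part (a minimal sketch; the hands-on notebooks cover the details):

    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import SubprocVecEnv

    # 4 CartPole environments, each stepped in its own process
    # (on some platforms this needs an `if __name__ == "__main__":` guard)
    vec_env = make_vec_env("CartPole-v1", n_envs=4, vec_env_cls=SubprocVecEnv)
    model = PPO("MlpPolicy", vec_env, verbose=1)
    model.learn(total_timesteps=25000)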


Backup Slides

SB3 vs other libs

[Figure: comparison of RL libraries]

Network Architecture

[Diagrams: SB3 policy structure]