Jon is a first-year master’s student who is interested in reinforcement learning (RL). In his eyes, RL seemed fascinating because he could use RL libraries such as Stable-Baselines3 (SB3) to train agents to play all kinds of games.
He quickly recognized Proximal Policy Optimization (PPO) as a fast and versatile algorithm and wanted to implement PPO himself as a learning experience. Upon reading the paper, Jon thought to himself, “huh, this is pretty straightforward.” He then opened a code editor and started writing PPO.
CartPole-v1 from Gym was his chosen simulation environment, and before long, Jon made PPO work with CartPole-v1. He had a great time and felt motivated to make his PPO work with more interesting environments, such as the Atari games and MuJoCo robotics tasks. “How cool would that be?” he thought.
However, he soon struggled. Making PPO work with Atari and MuJoCo seemed more challenging than anticipated. Jon then looked for reference implementations online but was quickly overwhelmed: unofficial repositories all appeared to do things differently, whereas he just could not read the TensorFlow 1.x code in the official repo. Fortunately, Jon stumbled across two recent papers that explain PPO’s implementations. “This is it!” he grinned.
Failing to control his excitement, Jon started running around in the office, accidentally bumping into Sam, whom Jon knew was working on RL. They then had the following conversation:
Sam: “… the MultiDiscrete action space, where you can use multiple discrete values to describe an action. Do you know how that works?”

Sam: “Also, if I gave you an environment library (e.g., numpy, gym, …) and a neural network library (e.g., torch, jax, …), could you code up PPO from scratch?”

And the blog post is here! Instead of doing ablation studies and making recommendations on which details matter, this blog post takes a step back and focuses on reproductions of PPO’s results in all accounts. Specifically, this blog post complements prior work in the following ways:
- We use the code from the openai/baselines GitHub repository (the official repository for PPO). As we will show, the code in the openai/baselines repository has undergone several refactorings which could produce different results from the original paper. So it is important to recognize which version of the official implementation is worth studying.
- We cover implementation details not addressed by prior work, such as the MultiDiscrete action spaces implementation detail.

Our ultimate purpose is to help people understand the PPO implementation through and through, reproduce past results with high fidelity, and facilitate customization for new research. To make research reproducible, we have made source code available at https://github.com/vwxyzjn/ppo-implementation-details.
PPO is a policy gradient algorithm proposed by Schulman et al., (2017). As a refinement to Trust Region Policy Optimization (TRPO) (Schulman et al., 2015), PPO uses a simpler clipped surrogate objective, omitting the expensive second-order optimization presented in TRPO. Despite this simpler objective, Schulman et al., (2017) show PPO has higher sample efficiency than TRPO in many control tasks. PPO also has good empirical performance in the arcade learning environment (ALE), which contains Atari games.
To facilitate more transparent research, Schulman et al., (2017) have made the source code of PPO available in the openai/baselines GitHub repository with the code name pposgd (commit da99706 on 7/20/2017). Later, the openai/baselines maintainers introduced a series of revisions. The key events include:

- The maintainers introduced a refactored version code-named ppo2 and renamed pposgd to ppo1. According to a GitHub issue, one maintainer suggests ppo2 should offer better GPU utilization by batching observations from multiple simulation environments.
- The maintainers benchmarked ppo2, producing the MuJoCo benchmark.
- The maintainers benchmarked ppo2, producing the Atari benchmark.
- ppo2 (ea25b9e) is the latest version of openai/baselines to date.

To our knowledge, ppo2 (ea25b9e) is the base of many PPO-related resources:

- Stable-Baselines3's PPO matches the implementation details of ppo2 (ea25b9e) closely.
- CleanRL's PPO is based on ppo2 (ea25b9e).

In recent years, reproducing PPO’s results has become a challenging issue. The following table collects the best-reported performance of PPO in popular RL libraries in Atari and MuJoCo environments.
| RL Library | Benchmark Source | Breakout | Pong | BeamRider | Hopper | Walker2d | HalfCheetah |
|---|---|---|---|---|---|---|---|
| Baselines pposgd / ppo1 (da99706) | paper ($) | 274.8 | 20.7 | 1590 | ~2250 | ~3000 | ~1750 |
| Baselines ppo2 (7bfbcf1 and ea68f3b) | docs (*) | 114.26 | 13.68 | 1299.25 | 2316.16 | 3424.95 | 1668.58 |
| Baselines ppo2 (ea25b9e) | this blog post (*) | 409.265 ± 30.98 | 20.59 ± 0.40 | 2627.96 ± 625.751 | 2448.73 ± 596.13 | 3142.24 ± 982.25 | 2148.77 ± 1166.023 |
| Stable-Baselines3 | docs (0) (^) | 398.03 ± 33.28 | 20.98 ± 0.10 | 3397.00 ± 1662.36 | 2410.43 ± 10.02 | 3478.79 ± 821.70 | 5819.09 ± 663.53 |
| CleanRL | docs (1) (*) | ~402 | ~20.39 | ~2131 | ~2685 | ~3753 | ~1683 |
| Ray/RLlib | repo (2) (*) | 201 | - | 4480 | - | - | 9664 |
| SpinningUp | docs (3) (^) | - | - | - | ~2500 | ~2500 | ~3000 |
| ChainerRL | paper (4) (*) | - | - | - | 2719 ± 67 | 2994 ± 113 | 2404 ± 185 |
| Tianshou | paper (5) (^) | - | - | - | 7337.4 ± 1508.2 | 3127.7 ± 413.0 | 4895.6 ± 704.3 |
| Tonic | paper (6) (^) | - | - | - | ~2000 | ~4500 | ~5000 |
(-): No publicly reported metrics available
($): The experiments use the v1 MuJoCo environments
(*): The experiments use the v2 MuJoCo environments
(^): The experiments use the v3 MuJoCo environments
(0): 1M steps for MuJoCo experiments, 10M steps for Atari games, 1 random seed
(1): 2M steps for MuJoCo experiments, 10M steps for Atari games, 2 random seeds
(2): 25M steps and 10 workers (5 envs per worker) for Atari experiments; 44M steps and 16 workers for MuJoCo experiments; 1 random seed
(3): 3M steps, PyTorch version, 10 random seeds
(4): 2M steps, 10 random seeds
(5): 3M steps, 10 random seeds
(6): 5M steps, 10 random seeds
We offer several observations.
- The revisions to openai/baselines are not without performance consequences. Reproducing PPO’s results is challenging partly because even the original implementation could produce inconsistent results.
- ppo2 (ea25b9e) and libraries matching its implementation details have reported rather similar results. In comparison, other libraries have usually reported more diverse results.

Despite the complicated situation, we have found ppo2 (ea25b9e) to be an implementation worth studying. It obtains good performance in both Atari and MuJoCo tasks. More importantly, it also incorporates advanced features such as LSTM and treatment of the MultiDiscrete action space, unlocking applications to more complicated games such as real-time strategy games. As such, we define ppo2 (ea25b9e) as the official PPO implementation and base the remainder of this blog post on this implementation.
In this section, we introduce five categories of implementation details and implement them in PyTorch from scratch:

1. 13 core implementation details
2. 9 Atari-specific implementation details
3. 9 details for continuous action domains (e.g., MuJoCo)
4. 5 LSTM implementation details
5. 1 MultiDiscrete action space implementation detail

For each category (except the first one), we benchmark our implementation against the original implementation in three environments, each with three random seeds.
We first introduce the 13 core implementation details commonly used regardless of the task. To help understand how to code these details in PyTorch, we have prepared a line-by-line video tutorial as follows. Note that the video tutorial skips over the 12th and 13th implementation details, hence the video has the title “11 Core Implementation Details”.
envs = VecEnv(num_envs=N)  # N parallel environments
agent = Agent()
next_obs = envs.reset()
for update in range(1, total_timesteps // (N*M)):
    data = []  # reset the rollout storage each update so that `len(data) = N*M`
# ROLLOUT PHASE
for step in range(0, M):
obs = next_obs
action, other_stuff = agent.get_action(obs)
next_obs, reward, done, info = envs.step(
action
) # step in N environments
data.append([obs, action, reward, done, other_stuff]) # store data
# LEARNING PHASE
agent.learn(data) # `len(data) = N*M`
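In practice, the VecEnv above can be, for example, gym's built-in vectorized environment API. A minimal sketch, assuming a gym version that provides gym.vector.SyncVectorEnv and whose reset() returns only the observations:

```python
import gym

def make_env(env_id):
    def thunk():
        return gym.make(env_id)
    return thunk

# N = 4 sub-environments stepped in lockstep; envs.step() takes a batch of 4 actions
envs = gym.vector.SyncVectorEnv([make_env("CartPole-v1") for _ in range(4)])
next_obs = envs.reset()
```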
$N$ also has other names: the number of environments, num_envs, and n_envs. $M$ also has other names: the number of steps, the sampling horizon, nsteps, and num_steps. $N*M$ is also known as the fixed-length trajectory segments in the original PPO paper.

In the simplest case, a vectorized environment corresponds to a single multiplayer game with $N$ players. If we run an RL algorithm in this environment, we are doing self-play without historical opponents. This setup can be straightforwardly extended to having $K$ concurrent games with $H$ players each, with $N = H*K$.

Note that $N$ is the num_envs (decision C1) and $M*N$ is the iteration_size (decision C2) in Andrychowicz, et al. (2021), who suggest increasing $N$ (such as $N=256$) boosts the training throughput but makes the performance worse. They argued the performance deterioration was due to “shortened experience chunks” ($M$ becomes smaller due to the increase in $N$ in their setup) and “earlier value bootstrapping.” While we agree increasing $N$ could hurt sample efficiency, we argue the evaluation should be based on wall-clock time efficiency. That is, if the algorithm terminates much sooner with a larger $N$ compared to other configurations, why not run the algorithm longer? Although Brax is a different robotics simulator, it follows this idea and can train a viable agent in similar tasks with PPO using a massive $N=2048$ and a small $M=20$, yet finishes the training in one minute.

In contrast, consider the following pseudocode, which uses a single environment and learns from complete episodes:

env = Env()
agent = Agent()
for episode in range(1, num_episodes):
    next_obs = env.reset()
    data = []  # store one episode of data
for step in range(1, max_episode_horizon):
obs = next_obs
action, other_stuff = agent.get_action(obs)
next_obs, reward, done, info = env.step(action)
data.append([obs, action, reward, done, other_stuff]) # store data
if done:
break
agent.learn(data)
Unlike the whole-episode approach above, the vectorized architecture learns from fixed-length partial trajectories. For example, suppose each of the $N=2$ environments has 200 steps per episode and $M=100$: then next_obs is the 101st observation from these two environments, and the agent can keep doing rollouts and learn from steps 101 to 200 of the 2 environments. Essentially, the agent learns partial trajectories of the episode, $M$ steps at a time.

Orthogonal initialization of weights and constant initialization of biases: the code for such initialization is filed under a2c in the openai/baselines library (a2c/utils.py#L58), when in fact it is used for other algorithms such as PPO. In general, the weights of hidden layers use orthogonal initialization of weights with scaling np.sqrt(2), and the biases are set to 0, as shown in the CNN initialization for Atari (common/models.py#L15-L26) and the MLP initialization for MuJoCo (common/models.py#L75-L103). However, the policy output layer weights are initialized with a scale of 0.01, and the value output layer weights are initialized with a scale of 1 (common/policies.py#L49-L63). Note that the orthogonal initialization of openai/baselines (a2c/utils.py#L20-L35) is different from that of pytorch/pytorch (torch.nn.init.orthogonal_). However, we consider this to be a very low-level detail that should not impact the performance.

The Adam optimizer's epsilon parameter: PPO sets the epsilon parameter to 1e-5, which is different from the default epsilon of 1e-8 in PyTorch and 1e-7 in TensorFlow. We list this implementation detail because the epsilon parameter is neither mentioned in the paper nor a configurable parameter in the PPO implementation. While this implementation detail may seem overly specific, it is important that we match it for a high-fidelity reproduction.
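A hedged PyTorch sketch of a layer_init helper consistent with these two details; it mirrors the layer_init calls used in the pseudocode later in this post, and the optimizer line only illustrates the eps=1e-5 setting (the learning rate shown is the Atari value):

```python
import numpy as np
import torch
import torch.nn as nn

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    # orthogonal initialization for the weights, constant (0) initialization for the biases
    nn.init.orthogonal_(layer.weight, std)
    nn.init.constant_(layer.bias, bias_const)
    return layer

# hidden layers use std=sqrt(2); the policy head would use std=0.01 and the value head std=1.0
hidden = layer_init(nn.Linear(64, 64))

# Adam with eps=1e-5 instead of PyTorch's default 1e-8
optimizer = torch.optim.Adam(hidden.parameters(), lr=2.5e-4, eps=1e-5)
```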
Andrychowicz, et al. (2021) perform a grid search on the Adam optimizer's parameters (decisions C24, C26, C28), recommend $\beta_1 = 0.9$, and use TensorFlow's default epsilon parameter 1e-7. Engstrom, Ilyas, et al., (2020) use the default PyTorch epsilon parameter 1e-8.
Adam learning rate annealing: in Atari experiments, the learning rate linearly decays from 2.5e-4 to 0 as the number of timesteps increases; in MuJoCo, the learning rate linearly decays from 3e-4 to 0.
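A hedged sketch of this linear annealing; the function name and the update/num_updates arguments are illustrative, not from the original code:

```python
def anneal_lr(optimizer, update, num_updates, initial_lr=2.5e-4):
    # linearly decay the learning rate to 0 as `update` goes from 1 to num_updates
    frac = 1.0 - (update - 1.0) / num_updates
    optimizer.param_groups[0]["lr"] = frac * initial_lr
```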
Generalized Advantage Estimation (GAE): PPO uses GAE to estimate the advantages. Note that gym environments have a time limit and will truncate themselves if they run too long. For example, CartPole-v1 has a time limit of 500 steps (see link) and will return done=True if the game lasts for more than 500 steps. While the PPO implementation does not estimate the value of the terminal state in truncated environments, we (intuitively) should. Nonetheless, for high-fidelity reproduction, we did not implement the correct handling for truncated environments. Also note that the return is computed as returns = advantages + values, which corresponds to $TD(\lambda)$ and therefore is not Monte Carlo for value estimation.

Value function loss clipping: PPO clips the value function like the PPO's clipped surrogate objective. Given $V_{targ} = $ returns = advantages + values, PPO fits the value network by minimizing the following loss:

$$L^{V} = \max\left[ \left(V_{\theta_t} - V_{targ}\right)^2, \left(\operatorname{clip}\left(V_{\theta_t}, V_{\theta_{t-1}} - \varepsilon, V_{\theta_{t-1}} + \varepsilon\right) - V_{targ}\right)^2 \right]$$

Overall loss and entropy bonus: the overall loss is computed as loss = policy_loss - entropy * entropy_coefficient + value_loss * value_coefficient, which maximizes an entropy bonus term. Note that the policy parameters and value parameters share the same optimizer.

Global gradient clipping: for each update iteration, the gradients of the policy and value networks are rescaled so that their “global l2 norm” does not exceed 0.5.
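To make the last three details concrete, here is a hedged PyTorch sketch; the function name and arguments are ours, not from the original code. newvalue, old_values, and returns denote the mini-batch value predictions, the rollout-time value predictions, and returns = advantages + values, respectively:

```python
import torch
import torch.nn as nn

def update_step(agent, optimizer, policy_loss, entropy, newvalue, old_values, returns,
                clip_coef=0.1, entropy_coefficient=0.01, value_coefficient=0.5):
    # value function loss clipping: clip the new value prediction around the old one
    v_loss_unclipped = (newvalue - returns) ** 2
    v_clipped = old_values + torch.clamp(newvalue - old_values, -clip_coef, clip_coef)
    v_loss_clipped = (v_clipped - returns) ** 2
    value_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()

    # overall loss: minimize the policy and value losses while maximizing the entropy bonus
    loss = policy_loss - entropy * entropy_coefficient + value_loss * value_coefficient

    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(agent.parameters(), 0.5)  # global gradient clipping at 0.5
    optimizer.step()
    return loss
```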
Debug variables: the PPO implementation logs the following variables during training:

- policy_loss: the mean policy loss across all data points.
- value_loss: the mean value loss across all data points.
- entropy_loss: the mean entropy value across all data points.
- clipfrac: the fraction of the training data that triggered the clipped objective.
- approxkl: the approximate Kullback–Leibler divergence, measured by (-logratio).mean(), which corresponds to the k1 estimator in John Schulman's blog post on approximating KL divergence. Schulman's post also suggests using an alternative estimator, ((ratio - 1) - logratio).mean(), which is unbiased and has less variance.

Shared and separate MLP networks for policy and value functions: by default, the policy and value functions share a base MLP network and use separate policy and value heads, as in the following pseudocode:

network = Sequential(
layer_init(Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
Tanh(),
layer_init(Linear(64, 64)),
Tanh(),
)
value_head = layer_init(Linear(64, 1), std=1.0)
policy_head = layer_init(Linear(64, envs.single_action_space.n), std=0.01)
hidden = network(observation)
value = value_head(hidden)
action = Categorical(logits=policy_head(hidden)).sample()
However, openai/baselines allows the shared network to be toggled off via the value_network='copy' argument. Then the pseudocode looks like this:
value_network = Sequential(
layer_init(Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
Tanh(),
layer_init(Linear(64, 64)),
Tanh(),
layer_init(Linear(64, 1), std=1.0),
)
policy_network = Sequential(
layer_init(Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
Tanh(),
layer_init(Linear(64, 64)),
Tanh(),
layer_init(Linear(64, envs.single_action_space.n), std=0.01),
)
value = value_network(observation)
action = Categorical(logits=policy_network(observation)).sample()
We incorporate the first 12 details and the separate-networks architecture to produce a self-contained ppo.py (link) that has 322 lines of code. Then, we make about 10 lines of code change to adopt the shared-network architecture, resulting in a self-contained ppo_shared.py (link) that has 317 lines of code. The following shows the file difference between the ppo.py (left) and ppo_shared.py (right).
Below are the benchmarked results.
While the shared-network architecture is the default setting in PPO, the separate-networks architecture clearly outperforms in simpler environments. The shared-network architecture probably performs worse due to the competing objectives of the policy and value functions. For this reason, we implement the separate-networks architecture in the video tutorial.
Next, we introduce the 9 Atari-specific implementation details. To help understand how to code these details in PyTorch, we have prepared a line-by-line video tutorial.
NoopResetEnv (common/atari_wrappers.py#L12), Environment Preprocessing: this wrapper samples initial states by taking a random number of no-op actions on reset. NoopResetEnv is a way to inject stochasticity into the environment.

MaxAndSkipEnv (common/atari_wrappers.py#L97), Environment Preprocessing: this wrapper skips frames (repeating the agent's last action on skipped frames) and max-pools over the last two frames. As Mnih et al. (2015) describe:
More precisely, the agent sees and selects actions on every $k$-th frame instead of every frame, and its last action is repeated on skipped frames. Because running the emulator forward for one step requires much less computation than having the agent select an action, this technique allows the agent to play roughly $k$ times more games without significantly increasing the runtime. We use $k=4$ for all games. […] First, to encode a single frame we take the maximum value for each pixel color value over the frame being encoded and the previous frame. This was necessary to remove flickering that is present in games where some objects appear only in even frames while other objects appear only in odd frames, an artifact caused by the limited number of sprites Atari 2600 can display at once.
EpisodicLifeEnv (common/atari_wrappers.py#L61), Environment Preprocessing: this wrapper treats the end of a life as the end of an episode. As Mnih et al. (2015) describe:

For games where there is a life counter, the Atari 2600 emulator also sends the number of lives left in the game, which is then used to mark the end of an episode during training.
FireResetEnv (common/atari_wrappers.py#L41), Environment Preprocessing: this wrapper takes the FIRE action on reset for environments that are fixed until firing.

WarpFrame (image transformation) (common/atari_wrappers.py#L134), Environment Preprocessing: this wrapper extracts the Y channel of the frame and resizes it to 84x84. As Mnih et al. (2015) describe:

Second, we then extract the Y channel, also known as luminance, from the RGB frame and rescale it to 84x84.

In our implementation, the same effect can be achieved with the following gym wrappers:
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
ClipRewardEnv (common/atari_wrappers.py#L125), Environment Preprocessing: this wrapper bins each reward to {+1, 0, -1} by its sign. As Mnih et al. (2015) describe:

As the scale of scores varies greatly from game to game, we clipped all positive rewards at 1 and all negative rewards at -1, leaving 0 rewards unchanged. Clipping the rewards in this manner limits the scale of the error derivatives and makes it easier to use the same learning rate across multiple games. At the same time, it could affect the performance of our agent since it cannot differentiate between rewards of different magnitude.
FrameStack (common/atari_wrappers.py#L188), Environment Preprocessing: this wrapper stacks the last $m$ frames so that the agent can infer the movement of objects. As Mnih et al. (2015) describe:

The function $\phi$ from algorithm 1 described below applies this preprocessing to the $m$ most recent frames and stacks them to produce the input to the Q-function, in which $m=4$.
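Putting the preprocessing wrappers above together, a hedged sketch of an Atari environment factory could look like the following; it borrows the wrapper classes from stable_baselines3.common.atari_wrappers and gym.wrappers, and the function name and default environment id are ours:

```python
import gym
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)

def make_atari_env(env_id="BreakoutNoFrameskip-v4"):
    env = gym.make(env_id)
    env = NoopResetEnv(env, noop_max=30)   # inject stochasticity via random no-ops on reset
    env = MaxAndSkipEnv(env, skip=4)       # frame skipping and max-pooling over the last two frames
    env = EpisodicLifeEnv(env)             # end of life == end of episode
    if "FIRE" in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)            # press FIRE on reset when the game requires it
    env = ClipRewardEnv(env)               # bin rewards to {+1, 0, -1} by their sign
    env = gym.wrappers.ResizeObservation(env, (84, 84))
    env = gym.wrappers.GrayScaleObservation(env)
    env = gym.wrappers.FrameStack(env, 4)  # stack the last 4 frames
    return env
```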
For Atari games, PPO uses a shared convolutional network (following Mnih et al., 2015) to extract features for both the policy and value functions, as shown in the following pseudocode:

hidden = Sequential(
layer_init(Conv2d(4, 32, 8, stride=4)),
ReLU(),
layer_init(Conv2d(32, 64, 4, stride=2)),
ReLU(),
layer_init(Conv2d(64, 64, 3, stride=1)),
ReLU(),
Flatten(),
layer_init(Linear(64 * 7 * 7, 512)),
ReLU(),
)
policy = layer_init(Linear(512, envs.single_action_space.n), std=0.01)
value = layer_init(Linear(512, 1), std=1)

If the policy and value functions instead used separate CNN networks, the pseudocode would look like this:

policy = Sequential(
layer_init(Conv2d(4, 32, 8, stride=4)),
ReLU(),
layer_init(Conv2d(32, 64, 4, stride=2)),
ReLU(),
layer_init(Conv2d(64, 64, 3, stride=1)),
ReLU(),
Flatten(),
layer_init(Linear(64 * 7 * 7, 512)),
ReLU(),
layer_init(Linear(512, envs.single_action_space.n), std=0.01)
)
value = Sequential(
layer_init(Conv2d(4, 32, 8, stride=4)),
ReLU(),
layer_init(Conv2d(32, 64, 4, stride=2)),
ReLU(),
layer_init(Conv2d(64, 64, 3, stride=1)),
ReLU(),
Flatten(),
layer_init(Linear(64 * 7 * 7, 512)),
ReLU(),
layer_init(Linear(512, 1), std=1)
)
To run the experiments, we match the hyperparameters used in the original implementation as follows.
# https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py
def atari():
return dict(
nsteps=128, nminibatches=4,
lam=0.95, gamma=0.99, noptepochs=4, log_interval=1,
ent_coef=.01,
lr=lambda f : f * 2.5e-4,
cliprange=0.1,
)
These hyperparameters are:

- nsteps is the $M$ explained in this blog post.
- nminibatches is the number of minibatches used for the update (i.e., our 6th implementation detail).
- lam is the GAE's $\lambda$ parameter.
- gamma is the discount factor.
- noptepochs is the $K$ epochs in the original PPO paper.
- ent_coef is the entropy_coefficient in our 10th implementation detail.
- lr=lambda f : f * 2.5e-4 is a learning rate schedule (i.e., our 4th implementation detail).
- cliprange=0.1 is the clipping parameter $\epsilon$ in the original PPO paper.

Note that the number of environments parameter $N$ (i.e., num_envs) is set to the number of CPUs in the computer (common/cmd_util.py#L167), which is strange. We have chosen instead to match the N=8 used in the paper (the paper listed the parameter as “number of actors, 8”).
As shown below, we make ~40 lines of code change to ppo.py to incorporate these 9 details, resulting in a self-contained ppo_atari.py (link) that has 339 lines of code. The following shows the file difference between the ppo.py (left) and ppo_atari.py (right).
Below are the benchmarked results.
Next, we introduce the 9 details for continuous action domains such as MuJoCo tasks. To help understand how to code these details in PyTorch, we have prepared a line-by-line video tutorial. Note that the video tutorial skips over the 4th implementation detail, hence the video has the title “8 Details for Continuous Actions”.
Continuous actions are sampled from a normal distribution whose log std is set to be state-independent and initialized to be 0. Together with separate MLP networks for the policy and value functions, the pseudocode looks like this:

value_network = Sequential(
layer_init(Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
Tanh(),
layer_init(Linear(64, 64)),
Tanh(),
layer_init(Linear(64, 1), std=1.0),
)
policy_mean = Sequential(
layer_init(Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
Tanh(),
layer_init(Linear(64, 64)),
Tanh(),
    layer_init(Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),
)
policy_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))
value = value_network(observation)
action_mean = policy_mean(observation)
probs = Normal(
    action_mean,
    policy_logstd.expand_as(action_mean).exp(),
)
action = probs.sample()
logprob = probs.log_prob(action).sum(1)
Other continuous-action details concern observation and reward pre-processing:

- Observation normalization: the VecNormalize wrapper pre-processes the observation before feeding it to the PPO agent. The raw observation is normalized by subtracting its running mean and dividing by its standard deviation.
- Observation clipping: the normalized observation is further clipped by VecNormalize to a range, usually [−10, 10].
- Reward scaling: VecNormalize also applies a certain discount-based scaling scheme, where the rewards are divided by the standard deviation of a rolling discounted sum of the rewards (without subtracting and re-adding the mean).
- Reward clipping: the scaled rewards are further clipped by VecNormalize to a range, usually [−10, 10].

We make ~25 lines of code change to ppo.py to incorporate these 9 details, resulting in a self-contained ppo_continuous_action.py (link) that has 331 lines of code. The following shows the file difference between the ppo.py (left) and ppo_continuous_action.py (right).
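As a quick illustration of the four normalization and clipping details above, here is a hedged sketch using gym's wrapper API, assuming a gym version that ships NormalizeObservation, NormalizeReward, TransformObservation, and TransformReward (the environment id is illustrative):

```python
import gym
import numpy as np

env = gym.make("HalfCheetah-v2")
env = gym.wrappers.NormalizeObservation(env)  # running mean/std observation normalization
env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))  # observation clipping
env = gym.wrappers.NormalizeReward(env, gamma=0.99)  # discount-based reward scaling
env = gym.wrappers.TransformReward(env, lambda r: np.clip(r, -10, 10))  # reward clipping
```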
To run the experiments, we match the hyperparameters used in the original implementation as follows.
# https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py
def mujoco():
return dict(
nsteps=2048,
nminibatches=32,
lam=0.95,
gamma=0.99,
noptepochs=10,
log_interval=1,
ent_coef=0.0,
lr=lambda f: 3e-4 * f,
cliprange=0.2,
value_network='copy'
)
Note that value_network='copy' means to use separate MLP networks for the policy and value functions (i.e., the 4th implementation detail in this section). Also, the number of environments parameter $N$ (i.e., num_envs) is set to 1 (common/cmd_util.py#L167). Below are the benchmarked results.
Next, we introduce the 5 details for implementing LSTM.
- Layer initialization for the LSTM layers: the LSTM weights are initialized with std=1 and the biases with 0.
- Reconstruction of LSTM states during training: the agent stores the initial_lstm_state before rollouts. During training, the agent then sequentially reconstructs the LSTM states based on the initial_lstm_state. This process ensures that we reconstruct the probability distributions used in rollouts.

We make ~60 lines of code change to ppo_atari.py to incorporate these 5 details, resulting in a self-contained ppo_atari_lstm.py (link) that has 385 lines of code. The following shows the file difference between the ppo_atari.py (left) and ppo_atari_lstm.py (right).
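A hedged sketch of the LSTM layer initialization described above (std=1 for the weights, 0 for the biases); the hidden size of 128 is illustrative:

```python
import torch.nn as nn

lstm = nn.LSTM(512, 128)  # e.g., on top of the 512-dimensional CNN features
for name, param in lstm.named_parameters():
    if "bias" in name:
        nn.init.constant_(param, 0)
    elif "weight" in name:
        nn.init.orthogonal_(param, 1.0)
```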
To run the experiments, we use the Atari hyperparameters again and remove the frame stack (i.e., setting the number of frames stacked to 1). Below are the benchmarked results.
MultiDiscrete action space detail

The MultiDiscrete space is often useful for describing the action space of more complicated games. Gym's official documentation explains the MultiDiscrete action space as follows:
# https://github.com/openai/gym/blob/2af816241e4d7f41a000f6144f22e12c8231a112/gym/spaces/multi_discrete.py#L8-L25
class MultiDiscrete(Space):
"""
- The multi-discrete action space consists of a series of discrete action spaces with different number of actions in each
- It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space
- It is parametrized by passing an array of positive integers specifying number of actions for each discrete action space
Note: Some environment wrappers assume a value of 0 always represents the NOOP action.
e.g. Nintendo Game Controller
- Can be conceptualized as 3 discrete action spaces:
1) Arrow Keys: Discrete 5 - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4] - params: min: 0, max: 4
2) Button A: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1
3) Button B: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1
- Can be initialized as
MultiDiscrete([ 5, 2, 2 ])
"""
...
Next, we introduce the 1 detail for handling the MultiDiscrete action space:

- Independent action components: in MultiDiscrete action spaces, the actions are represented with multiple discrete values. For example, the action \(a_t = [a^1_t, a^2_t] = [0, 1]\) might mean to press the up arrow key and press button A. To account for this difference, PPO treats \([a^1_t, a^2_t]\) as probabilistically independent action components, therefore calculating \(prob(a_t) = prob(a^1_t) \cdot prob(a^2_t)\).

This detail matters for games with large MultiDiscrete action spaces. For example, OpenAI Five's action space is essentially MultiDiscrete([ 30, 4, 189, 81 ]), as shown by the following quote:

All together this produces a combined factorized action space size of up to 30 × 4 × 189 × 81 = 1,837,080 dimensions.
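A hedged PyTorch sketch of this independent-components treatment (not the exact ppo2 code; the function name is ours, and nvec is the MultiDiscrete space's vector of component sizes, e.g., envs.single_action_space.nvec):

```python
import torch
from torch.distributions import Categorical

def sample_multidiscrete(logits, nvec):
    # split the flat logits into one chunk per discrete component
    split_logits = torch.split(logits, nvec.tolist(), dim=1)
    multi_categoricals = [Categorical(logits=l) for l in split_logits]
    action = torch.stack([c.sample() for c in multi_categoricals])
    # the joint log-probability is the sum of the per-component log-probabilities
    logprob = torch.stack([c.log_prob(a) for a, c in zip(action, multi_categoricals)]).sum(0)
    return action.T, logprob
```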
We make ~36 lines of code change to ppo_atari.py to incorporate this 1 detail, resulting in a self-contained ppo_multidiscrete.py (link) that has 335 lines of code. The following shows the file difference between the ppo_atari.py (left) and ppo_multidiscrete.py (right).
To run the experiments, we use the Atari hyperparameters again and use Gym-μRTS (Huang et al., 2021) as the simulation environment.
def gym_microrts():
return dict(
nsteps=128, nminibatches=4,
lam=0.95, gamma=0.99, noptepochs=4, log_interval=1,
ent_coef=.01,
lr=lambda f : f * 2.5e-4,
cliprange=0.1,
)
Below are the benchmarked results.
Next, we introduce 4 auxiliary techniques that are not used (by default) in the official PPO implementations but are potentially useful in special situations.
- Parallelized gradient update: the gradient update can be parallelized across multiple processes; this is done in ppo1 and not used by default in ppo2. Such a paradigm could improve training time by making use of all the available processes.
- Early stopping of the policy optimizations: an alternative way to prevent the policy from changing too aggressively is to early stop the $K$ update epochs (noptepochs) proposed in the original implementation by Schulman et al. (2017). This technique can be exposed via a target-KL flag (e.g., --target-kl 0.01), but is toggled off by default.
- Invalid action masking: the logits corresponding to invalid actions are masked out so that the agent never samples invalid actions.

Notably, we highlight the effect of invalid action masking. We make ~30 lines of code change to ppo_multidiscrete.py to incorporate invalid action masking, resulting in a self-contained ppo_multidiscrete_mask.py (link) that has 363 lines of code. The following shows the file difference between the ppo_multidiscrete.py (left) and ppo_multidiscrete_mask.py (right).
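A hedged sketch of invalid action masking (not the exact ppo_multidiscrete_mask.py code): the logits of invalid actions are replaced with a large negative number so that their probability becomes effectively zero and they receive no gradient:

```python
import torch
from torch.distributions import Categorical

def masked_categorical(logits, mask):
    # mask: boolean tensor with the same shape as logits; True marks a valid action
    neg_inf = torch.tensor(-1e8, dtype=logits.dtype, device=logits.device)
    return Categorical(logits=torch.where(mask, logits, neg_inf))
```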
To run the experiments, we use the Atari hyperparameters again and use an older version of Gym-μRTS (Huang et al., 2021) as the simulation environment. Below are the benchmarked results.
As shown under each section, our implementations match the results of the original implementation closely. This close matching also extends to other metrics such as policy and value losses. We have made an interactive HTML below for interested viewers to compare other metrics:
During our reproduction, we have found a number of useful debugging techniques. They are as follows:
- Seed everything and compare intermediate values against a reference implementation (e.g., print values.sum() and see if yours match the reference implementation). In the past, we have done this with the pytorch-a2c-ppo-acktr-gail repository and ultimately figured out a bug in our implementation.
- Check ratio=1: check if the ratio is always 1 during the first epoch and first mini-batch update, when the new and old policies are the same and therefore the ratio is 1 and there is nothing to clip. If the ratio is not 1, it means there is a bug and the program has not reconstructed the probability distributions used in rollouts.
- Check KL: the approx_kl metric often stays below 0.02, and if approx_kl becomes too high it usually means the policy is changing too quickly and there is a bug.
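As an illustration of the ratio=1 check, a hedged sketch (the function name and arguments are ours):

```python
import torch

def check_ratio_is_one(new_logprob, old_logprob):
    # during the first epoch and first mini-batch, the new and old policies are identical,
    # so the probability ratio should be exactly 1 and there is nothing to clip
    ratio = (new_logprob - old_logprob).exp()
    assert torch.allclose(ratio, torch.ones_like(ratio)), (
        "ratio != 1: the rollout probability distributions were not reconstructed correctly"
    )
```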
For readers who simply need PPO results rather than a from-scratch implementation, consider using a well-benchmarked implementation such as openai/baselines' PPO. If you are doing research using PPO, consider adopting the following recommendations to help improve the reproducibility of your work:

- Pin the versions of your dependencies. Otherwise, users might end up just running pip install -e ., which 80% of the time would fail to run due to some obscure errors. Having a pre-built docker image with all dependencies installed can also help in case the dependency packages are no longer hosted by package managers after deprecation.
- Use an experiment management tool to track metrics and hyperparameters, instead of manually saving plots with matplotlib and worrying about how to display data. Commercial solutions (usually more mature) include Weights and Biases and Neptune, and open-source solutions include Aim, ClearML, and Polyaxon.
- Consider single-file implementations: for example, ppo_atari.py contains all relevant code to handle Atari games. Such a paradigm has the following benefits at the cost of duplicate and harder-to-refactor code:
  - We do not need to jump through files to see how env.py, agent.py, and network.py work together like in typical RL libraries.
  - A single-file ppo.py has significantly less LOC compared to RL libraries' PPO. As a result, it's often easier to prototype new features without having to do subclassing and refactoring.
  - When adding new features, we can run a filediff between the current and past versions, and every line of code change is made explicit to us.

This blog post demonstrates that reproducing PPO is a non-trivial effort, even though PPO's source code is readily available for reference. Why is this the case? We think one important reason might be that modularity disperses implementation details.
Almost all RL libraries have adopted modular design, featuring different modules / files like env.py, agent.py, network.py, utils.py, runner.py, etc. The nature of modularity necessarily puts implementation details into different files, which is usually great from a software engineering perspective. That is, we don't have to know how other components work when we just work on env.py. Being able to treat other components as black boxes has empowered us to work on large and complicated systems for decades.
However, this practice might clash hard with ML / RL: as the library grows, it becomes harder and harder to grasp all implementation details w.r.t. an algorithm, whereas recognizing all implementation details has become increasingly important, as indicated by this blog post, Engstrom, Ilyas, et al., 2020, and Andrychowicz, et al., 2021. So what can we do?
Modular design still offers numerous benefits, such as 1) an easy-to-use interface, 2) integrated test cases, and 3) the ability to plug in different components, among others. To this end, good RL libraries are valuable, and we encourage their maintainers to keep writing good documentation and to refactor their libraries to adopt new features. For algorithmic researchers, however, we recommend considering single-file implementations because they are straightforward to read and extend.
Not necessarily. The high-throughput variant Asynchronous PPO (APPO) (Berner et al., 2019) has gained more attention in recent years. APPO eliminates the idle time in the original PPO implementation (e.g., having to wait for all $N$ environments to return observations), resulting in much higher throughput and GPU and CPU utilization. However, APPO involves performance-reducing side effects, namely stale experiences (Espeholt et al., 2018), and we have found insufficient evidence to ascertain its improvement. The biggest issue is:
Under-benchmarked APPO implementations: RLlib has an APPO implementation, yet its documentation contains no benchmark information and suggests “APPO is not always more efficient; it is often better to use standard PPO or IMPALA.” Sample Factory (Petrenko et al., 2020) presents more benchmark results, but its support for Atari games is still a work in progress. To our knowledge, there is no APPO implementation that simultaneously works with Atari games, MuJoCo or Pybullet tasks, MultiDiscrete action spaces, and an LSTM.
While APPO is intuitively valuable for CPU-intensive tasks such as Dota 2, this blog post recommends an alternative approach to speed up PPO: make the vectorized environments really fast. Initially, the vectorized environments were implemented in Python, which is slow. More recently, researchers have proposed accelerated vectorized environments. For example, Nvidia's Isaac Gym uses torch to write hardware-accelerated vectorized environments, allowing the users to spin up $N=4096$ environments easily and solve tasks such as Ant in minutes compared to hours of training in MuJoCo. In the following section, we demonstrate accelerated training with PPO + envpool in the Atari game Pong.

Envpool is a recent work that offers accelerated vectorized environments for Atari by leveraging C++ and thread pools. Our PPO gets a free and side-effect-free performance boost by simply adopting it. We make ~60 lines of code change to ppo_atari.py to incorporate this 1 detail, resulting in a self-contained ppo_atari_envpool.py (link) that has 365 lines of code. The following shows the file difference between the ppo_atari.py (left) and ppo_atari_envpool.py (right).
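A minimal sketch of creating envpool's vectorized Atari environment, assuming pip install envpool (the exact arguments may differ across envpool versions):

```python
import envpool

# 8 Atari environments running in a C++ thread pool, exposed through a gym-style API
envs = envpool.make("Pong-v5", env_type="gym", num_envs=8)
obs = envs.reset()
```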
As shown below, Envpool + PPO runs 3x faster without side effects (as in no loss of sample efficiency):
Two quick notes: 1) the performance deterioration in BeamRider is largely due to a degenerate random seed, and 2) Envpool uses the v5 ALE environments but has processed them the same way as the v4 ALE environments used in our previous experiments. Furthermore, by tuning the hyperparameters, we obtained a run that solves Pong in 5 mins. This performance is even comparable to IMPALA’s (Espeholt et al., 2018) results:
We think this raises a practical consideration: adopting async RL such as IMPALA could be more difficult than just making your vectorized environments fast.
Given this blog post, we believe the community understands PPO better and would be in a much better place to make improvements. Here are a few suggested areas for research.
Reproducing PPO’s results has been difficult in the past few years. While recent works conducted ablation studies to provide insight on the implementation details, these works are not structured as tutorials and only focus on details concerning robotics tasks. As a result, reproducing PPO from scratch can become a daunting experience. Instead of introducing additional improvements or doing further ablation studies, this blog post takes a step back and focuses on delivering a thorough reproduction of PPO in all accounts, as well as aggregating, documenting, and cataloging its most salient implementation details. This blog post also points out software engineering challenges in PPO and further efficiency improvement via the accelerated vectorized environments. With these, we believe this blog post will help people understand PPO faster and better, facilitating customization and research upon this versatile RL algorithm.
We thank Weights and Biases for providing a free academic license that helps us track the experiments. Shengyi would like to personally thank Angelica Pan, Scott Condron, Ivan Goncharov, Morgan McGuire, Jeremy Salwen, Cayla Sharp, Lavanya Shukla, and Aakarshan Chauhan for supporting him in making the video tutorials.
In this appendix, we introduce one detail for reproducing PPO’s results in the procgen environments (Cobbe et al, 2020).
We make ~60 lines of code change to ppo_atari.py to incorporate this detail, resulting in a self-contained ppo_procgen.py (link) that has 354 lines of code. The following shows the file difference between the ppo_atari.py (left) and ppo_procgen.py (right).
To run the experiment, we try to match the default setting in openai/train-procgen, except that we use the easy distribution mode and total_timesteps=25e6 to save compute.
def procgen():
return dict(
nsteps=256, nminibatches=8,
lam=0.95, gamma=0.999, noptepochs=3, log_interval=1,
ent_coef=.01,
lr=5e-4,
cliprange=0.2,
vf_coef=0.5, max_grad_norm=0.5,
)
network = build_impala_cnn(x, depths=[16,32,32], emb_size=256)
env = ProcgenEnv(
num_envs=64,
env_name="starpilot",
num_levels=0,
start_level=0,
distribution_mode="easy"
)
env = VecNormalize(venv=env, ob=False)
ppo2.learn(..., total_timesteps = 25_000_000)
Notice that the agent uses the IMPALA CNN (build_impala_cnn) instead of the Nature CNN used in the Atari experiments, and that VecNormalize normalizes only the rewards and not the observations (ob=False).
Below are the benchmarked results.