This is a PyTorch implementation of Proximal Policy Optimization - PPO.
PPO is a policy gradient method for reinforcement learning. Simple policy gradient methods do a single gradient update per sample (or a set of samples). Doing multiple gradient steps for a single sample causes problems because the policy deviates too much, producing a bad policy. PPO lets us do multiple gradient updates per sample by trying to keep the policy close to the policy that was used to sample data. It does so by clipping gradient flow if the updated policy is not close to the policy used to sample the data.
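To make this concrete, here is a minimal sketch of reusing one sampled batch for several clipped gradient steps. This is not the code from this module; `policy` (a callable returning a torch distribution), `optimizer`, and the batch tensors are hypothetical placeholders.

```python
import torch


def ppo_style_update(policy, optimizer, obs, actions, advantages, sampled_log_pi,
                     clip: float = 0.2, epochs: int = 4):
    """Reuse one batch of samples for several clipped gradient steps."""
    for _ in range(epochs):
        # Log-probabilities of the sampled actions under the current (updated) policy
        log_pi = policy(obs).log_prob(actions)
        # Ratio between the current policy and the policy that sampled the data
        ratio = torch.exp(log_pi - sampled_log_pi)
        # Clipping keeps the updated policy close to the sampling policy
        clipped_ratio = ratio.clamp(1.0 - clip, 1.0 + clip)
        loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```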
You can find an experiment that uses it here. The experiment uses Generalized Advantage Estimation.
import torch

from labml_helpers.module import Module
from labml_nn.rl.ppo.gae import GAE
Here's how the PPO update rule is derived.

We want to maximize the policy reward
$$\max_\theta J(\pi_\theta) = \mathop{\mathbb{E}}_{\tau \sim \pi_\theta}\Biggl[\sum_{t=0}^\infty \gamma^t r_t\Biggr]$$
where $r$ is the reward, $\pi$ is the policy, $\tau$ is a trajectory sampled from the policy, and $\gamma$ is the discount factor between $[0, 1]$.

So,
$$J(\pi_\theta) - J(\pi_{\theta_{OLD}}) = \mathop{\mathbb{E}}_{\tau \sim \pi_\theta}\Biggl[\sum_{t=0}^\infty \gamma^t A^{\pi_{\theta_{OLD}}}(s_t, a_t)\Biggr]$$
where $A^{\pi_{\theta_{OLD}}}(s_t, a_t)$ is the advantage of taking action $a_t$ in state $s_t$ under the old policy.

Define the discounted-future state distribution,
$$d^\pi(s) = (1 - \gamma) \sum_{t=0}^\infty \gamma^t P(s_t = s \mid \pi)$$

Then,
$$J(\pi_\theta) - J(\pi_{\theta_{OLD}}) = \frac{1}{1 - \gamma} \mathop{\mathbb{E}}_{s \sim d^{\pi_\theta},\, a \sim \pi_\theta}\Bigl[A^{\pi_{\theta_{OLD}}}(s, a)\Bigr]$$

Importance sampling $a$ from $\pi_{\theta_{OLD}}$,
$$J(\pi_\theta) - J(\pi_{\theta_{OLD}}) = \frac{1}{1 - \gamma} \mathop{\mathbb{E}}_{s \sim d^{\pi_\theta},\, a \sim \pi_{\theta_{OLD}}}\Biggl[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{OLD}}(a \mid s)} A^{\pi_{\theta_{OLD}}}(s, a)\Biggr]$$

Then we assume $d^{\pi_\theta}(s)$ and $d^{\pi_{\theta_{OLD}}}(s)$ are similar, giving
$$J(\pi_\theta) - J(\pi_{\theta_{OLD}}) \approx \frac{1}{1 - \gamma} \mathop{\mathbb{E}}_{s \sim d^{\pi_{\theta_{OLD}}},\, a \sim \pi_{\theta_{OLD}}}\Biggl[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{OLD}}(a \mid s)} A^{\pi_{\theta_{OLD}}}(s, a)\Biggr]$$
The error we introduce to $J(\pi_\theta) - J(\pi_{\theta_{OLD}})$ by this assumption is bound by the KL divergence between $\pi_\theta$ and $\pi_{\theta_{OLD}}$. Constrained Policy Optimization shows the proof of this. I haven't read it.
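As a sanity check on the importance sampling step (not part of this module), the expectation of the advantage under $\pi_\theta$ can be estimated from actions sampled by $\pi_{\theta_{OLD}}$ by re-weighting with the ratio; the categorical distributions and advantages below are made up for illustration.

```python
import torch
from torch.distributions import Categorical

# Hypothetical sampling (old) policy and updated (new) policy over three actions
pi_old = Categorical(probs=torch.tensor([0.5, 0.3, 0.2]))
pi_new = Categorical(probs=torch.tensor([0.4, 0.4, 0.2]))
advantage = torch.tensor([1.0, -0.5, 2.0])  # hypothetical advantage of each action

# Expectation of the advantage under the new policy, computed directly
direct = (pi_new.probs * advantage).sum()

# The same expectation estimated from actions drawn from the old policy
actions = pi_old.sample((100_000,))
ratio = torch.exp(pi_new.log_prob(actions) - pi_old.log_prob(actions))
estimate = (ratio * advantage[actions]).mean()

print(direct.item(), estimate.item())  # the two values should be close
```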
class ClippedPPOLoss(Module):
    def __init__(self):
        super().__init__()

    def forward(self, log_pi: torch.Tensor, sampled_log_pi: torch.Tensor,
                advantage: torch.Tensor, clip: float) -> torch.Tensor:
The ratio $r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{OLD}}(a_t \mid s_t)}$; this is different from the rewards $r_t$.
        ratio = torch.exp(log_pi - sampled_log_pi)
The ratio is clipped to be close to 1, and we take the minimum of the clipped and unclipped objectives:

$$\mathcal{L}^{CLIP}(\theta) = \mathop{\mathbb{E}}_t \Bigl[\min\Bigl(r_t(\theta)\,\bar{A_t},\ \text{clip}\bigl(r_t(\theta), 1 - \epsilon, 1 + \epsilon\bigr)\,\bar{A_t}\Bigr)\Bigr]$$

Taking the minimum means the gradient will only pull $\pi_\theta$ towards $\pi_{\theta_{OLD}}$ when the ratio is not between $1 - \epsilon$ and $1 + \epsilon$. This keeps the KL divergence between $\pi_\theta$ and $\pi_{\theta_{OLD}}$ constrained. Large deviations can cause a performance collapse, where the policy performance drops and doesn't recover because we are sampling from a bad policy.

Using the normalized advantage $\bar{A_t} = \frac{\hat{A_t} - \mu(\hat{A_t})}{\sigma(\hat{A_t})}$ introduces a bias to the policy gradient estimator, but it reduces variance a lot (see the usage sketch after this class).
        clipped_ratio = ratio.clamp(min=1.0 - clip,
                                    max=1.0 + clip)
        policy_reward = torch.min(ratio * advantage,
                                  clipped_ratio * advantage)

        # Fraction of samples where the ratio fell outside the clip range; useful for monitoring
        self.clip_fraction = (abs((ratio - 1.0)) > clip).to(torch.float).mean()

        # Negate because we minimize the loss but want to maximize the policy reward
        return -policy_reward.mean()
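A minimal usage sketch, assuming the imports and the class above; the tensors are hypothetical stand-ins for log-probabilities and GAE advantages, and the normalization is the $\bar{A_t}$ described above.

```python
# Hypothetical batch: log-probs recorded at sampling time, log-probs under the
# updated policy, and advantages (e.g. from GAE)
sampled_log_pi = torch.log(torch.rand(2048))
log_pi = sampled_log_pi + 0.05 * torch.randn(2048)
advantage = torch.randn(2048)

# Normalize advantages: a little bias, much less variance
norm_advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

ppo_loss = ClippedPPOLoss()
loss = ppo_loss(log_pi, sampled_log_pi, norm_advantage, clip=0.2)
```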
Similarly we clip the value function update as well:

$$\begin{aligned}
V^{\pi_\theta}_{CLIP}(s_t) &= \hat{V_t} + \text{clip}\Bigl(V^{\pi_\theta}(s_t) - \hat{V_t},\, -\epsilon,\, +\epsilon\Bigr) \\
\mathcal{L}^{VF}(\theta) &= \frac{1}{2} \mathop{\mathbb{E}}\biggl[\max\Bigl(\bigl(V^{\pi_\theta}(s_t) - R_t\bigr)^2,\ \bigl(V^{\pi_\theta}_{CLIP}(s_t) - R_t\bigr)^2\Bigr)\biggr]
\end{aligned}$$

where $\hat{V_t}$ is the value estimate recorded when the data was sampled and $R_t$ is the sampled return. Clipping makes sure the value function $V_\theta$ doesn't deviate significantly from $V_{\theta_{OLD}}$ (see the usage sketch after this class).
class ClippedValueFunctionLoss(Module):
    def forward(self, value: torch.Tensor, sampled_value: torch.Tensor, sampled_return: torch.Tensor, clip: float):
        # Keep the value estimate within ±clip of the value recorded at sampling time
        clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip)
        # Take the maximum of the clipped and unclipped squared errors
        vf_loss = torch.max((value - sampled_return) ** 2, (clipped_value - sampled_return) ** 2)
        return 0.5 * vf_loss.mean()
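A similar usage sketch with hypothetical tensors; `sampled_value` stands for the value estimate $\hat{V_t}$ recorded when the data was sampled, and `sampled_return` for the return $R_t$ computed from the rollout.

```python
value_loss_fn = ClippedValueFunctionLoss()

value = torch.randn(2048)                         # current predictions V(s_t)
sampled_value = value + 0.05 * torch.randn(2048)  # predictions at sampling time
sampled_return = torch.randn(2048)                # returns from the rollout

vf_loss = value_loss_fn(value, sampled_value, sampled_return, clip=0.2)
```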