This is a PyTorch implementation of Proximal Policy Optimization (PPO).
You can find an experiment that uses it here. The experiment uses Generalized Advantage Estimation.
import torch

from labml_helpers.module import Module
from labml_nn.rl.ppo.gae import GAE
We want to maximize the policy reward
$$\max_\theta J(\pi_\theta) = \mathop{\mathbb{E}}_{\tau \sim \pi_\theta} \Biggl[ \sum_{t=0}^\infty \gamma^t r_t \Biggr]$$
where $r$ is the reward, $\pi$ is the policy, $\tau$ is a trajectory sampled from the policy, and $\gamma$ is the discount factor in $[0, 1]$.
So, the improvement in the objective can be written in terms of advantages under the old policy,
$$J(\pi_\theta) - J(\pi_{\theta_{OLD}}) = \mathop{\mathbb{E}}_{\tau \sim \pi_\theta} \Biggl[ \sum_{t=0}^\infty \gamma^t A^{\pi_{\theta_{OLD}}}(s_t, a_t) \Biggr]$$
Define the discounted-future state distribution,
$$d^{\pi}(s) = (1 - \gamma) \sum_{t=0}^\infty \gamma^t P(s_t = s \,|\, \pi)$$
Then,
$$J(\pi_\theta) - J(\pi_{\theta_{OLD}}) = \frac{1}{1 - \gamma} \mathop{\mathbb{E}}_{s \sim d^{\pi_\theta},\, a \sim \pi_\theta} \Bigl[ A^{\pi_{\theta_{OLD}}}(s, a) \Bigr]$$
Importance sampling $a$ from $\pi_{\theta_{OLD}}$,
$$J(\pi_\theta) - J(\pi_{\theta_{OLD}}) = \frac{1}{1 - \gamma} \mathop{\mathbb{E}}_{s \sim d^{\pi_\theta}} \Biggl[ \mathop{\mathbb{E}}_{a \sim \pi_{\theta_{OLD}}} \Bigl[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{OLD}}(a|s)} A^{\pi_{\theta_{OLD}}}(s, a) \Bigr] \Biggr]$$
Then we assume $d^{\pi_\theta}(s)$ and $d^{\pi_{\theta_{OLD}}}(s)$ are similar, which gives the approximation
$$J(\pi_\theta) - J(\pi_{\theta_{OLD}}) \approx \frac{1}{1 - \gamma} \mathop{\mathbb{E}}_{s \sim d^{\pi_{\theta_{OLD}}}} \Biggl[ \mathop{\mathbb{E}}_{a \sim \pi_{\theta_{OLD}}} \Bigl[ \frac{\pi_\theta(a|s)}{\pi_{\theta_{OLD}}(a|s)} A^{\pi_{\theta_{OLD}}}(s, a) \Bigr] \Biggr]$$
The error this assumption introduces to $J(\pi_\theta) - J(\pi_{\theta_{OLD}})$ is bounded by the KL divergence between $\pi_\theta$ and $\pi_{\theta_{OLD}}$. Constrained Policy Optimization shows the proof of this. I haven't read it.
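This approximation is what the clipped surrogate objective optimizes. For reference, restating it from the PPO paper, with clipping range $\epsilon$, advantage estimate $\hat{A}_t$, and the probability ratio $r_t(\theta)$ computed in the code below,
$$\mathcal{L}^{CLIP}(\theta) = \mathop{\mathbb{E}}_{s_t, a_t \sim \pi_{\theta_{OLD}}} \biggl[ \min\Bigl( r_t(\theta) \hat{A}_t,\; \text{clip}\bigl(r_t(\theta), 1 - \epsilon, 1 + \epsilon\bigr) \hat{A}_t \Bigr) \biggr]$$
ClippedPPOLoss below returns the negative of this objective so that it can be minimized with gradient descent.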
class ClippedPPOLoss(Module):
    def __init__(self):
        super().__init__()

    def __call__(self, log_pi: torch.Tensor, sampled_log_pi: torch.Tensor,
                 advantage: torch.Tensor, clip: float) -> torch.Tensor:
We compute the ratio $r_t(\theta) = \frac{\pi_\theta (a_t|s_t)}{\pi_{\theta_{OLD}} (a_t|s_t)}$ from the log-probabilities; note that this is different from the reward $r_t$.
        ratio = torch.exp(log_pi - sampled_log_pi)
The ratio is clipped to be close to 1. We take the minimum so that the gradient only pulls $\pi_\theta$ towards $\pi_{\theta_{OLD}}$ when the ratio is not between $1 - \epsilon$ and $1 + \epsilon$. This keeps the KL divergence between $\pi_\theta$ and $\pi_{\theta_{OLD}}$ constrained. A large deviation can cause performance collapse, where the policy performance drops and doesn't recover because we are sampling from a bad policy.
Using the normalized advantage $\bar{A}_t = \frac{\hat{A}_t - \mu(\hat{A}_t)}{\sigma(\hat{A}_t)}$ introduces a bias to the policy gradient estimator, but it reduces variance a lot (a sketch of this normalization follows the class).
        clipped_ratio = ratio.clamp(min=1.0 - clip,
                                    max=1.0 + clip)
        policy_reward = torch.min(ratio * advantage,
                                  clipped_ratio * advantage)

        # Fraction of samples where the ratio fell outside the clipping range, useful for monitoring
        self.clip_fraction = ((ratio - 1.0).abs() > clip).to(torch.float).mean()

        # Negate the policy reward because the optimizer minimizes the loss
        return -policy_reward.mean()
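As a minimal usage sketch of this loss with normalized advantages: the tensors below are hypothetical stand-ins for data collected from a rollout, and the 1e-8 term is an assumed epsilon to avoid division by zero.

# Hypothetical rollout data: log-probabilities of the sampled actions under the
# current policy and the old (sampling) policy, and GAE advantage estimates
log_pi = torch.randn(64)
sampled_log_pi = torch.randn(64)
advantages = torch.randn(64)

# Normalize the advantages as described above
normalized_advantage = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

clipped_ppo_loss = ClippedPPOLoss()
loss = clipped_ppo_loss(log_pi, sampled_log_pi, normalized_advantage, clip=0.2)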
Similarly, we clip the value function update. With $R_t$ the sampled return and $V_{\theta_{OLD}}$ the value estimate from the sampling policy,
$$V^{CLIP}_\theta(s_t) = V_{\theta_{OLD}}(s_t) + \text{clip}\bigl(V_\theta(s_t) - V_{\theta_{OLD}}(s_t), -\epsilon, +\epsilon\bigr)$$
$$\mathcal{L}^{VF}(\theta) = \frac{1}{2} \mathop{\mathbb{E}} \Bigl[ \max\bigl((V_\theta(s_t) - R_t)^2,\; (V^{CLIP}_\theta(s_t) - R_t)^2\bigr) \Bigr]$$
Clipping makes sure the value function $V_\theta$ doesn't deviate significantly from $V_{\theta_{OLD}}$.
class ClippedValueFunctionLoss(Module):
    def __call__(self, value: torch.Tensor, sampled_value: torch.Tensor, sampled_return: torch.Tensor, clip: float):
        # Clip the new value estimate so it stays within `clip` of the value estimate from sampling
        clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip)
        # Take the larger of the clipped and unclipped squared errors against the sampled return
        vf_loss = torch.max((value - sampled_return) ** 2, (clipped_value - sampled_return) ** 2)
        return 0.5 * vf_loss.mean()
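Finally, a rough sketch of how these two losses might be combined into a single objective to minimize, in the spirit of the PPO paper's $\mathcal{L} = \mathop{\mathbb{E}} \bigl[ \mathcal{L}^{CLIP} - c_1 \mathcal{L}^{VF} + c_2 S \bigr]$ (maximized there, so we minimize its negative here). The coefficient values, the entropy_bonus term, and the function name are illustrative assumptions, not part of this file.

# Assumed coefficients (c_1 and c_2 in the paper); the actual experiment may use different values
value_loss_coef = 0.5
entropy_bonus_coef = 0.01

policy_loss_fn = ClippedPPOLoss()
value_loss_fn = ClippedValueFunctionLoss()

def ppo_objective(log_pi, sampled_log_pi, advantage,
                  value, sampled_value, sampled_return,
                  entropy_bonus, clip: float = 0.2):
    # Minimize -L^CLIP + c_1 * L^VF - c_2 * S; both loss classes already return values to minimize
    policy_loss = policy_loss_fn(log_pi, sampled_log_pi, advantage, clip)
    value_loss = value_loss_fn(value, sampled_value, sampled_return, clip)
    return policy_loss + value_loss_coef * value_loss - entropy_bonus_coef * entropy_bonus

Here entropy_bonus would be the mean entropy of the current action distribution, for example dist.entropy().mean() for a torch.distributions distribution.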