<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="An annotated implementation of Proximal Policy Optimization (PPO) algorithm in PyTorch."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Proximal Policy Optimization (PPO)"/>
<meta name="twitter:description" content="An annotated implementation of Proximal Policy Optimization (PPO) algorithm in PyTorch."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/rl/ppo/index.html"/>
<meta property="og:title" content="Proximal Policy Optimization (PPO)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Proximal Policy Optimization (PPO)"/>
<meta property="og:description" content="An annotated implementation of Proximal Policy Optimization (PPO) algorithm in PyTorch."/>
<title>Proximal Policy Optimization (PPO)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/rl/ppo/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">rl</a>
<a class="parent" href="index.html">ppo</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/rl/ppo/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
rel="nofollow">
<img alt="Join Slact"
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Proximal Policy Optimization (PPO)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of
<a href="https://arxiv.org/abs/1707.06347">Proximal Policy Optimization - PPO</a>.</p>
<p>PPO is a policy gradient method for reinforcement learning.
Simple policy gradient methods perform a single gradient update per sample (or per batch of samples).
Taking multiple gradient steps on the same sample causes problems,
because the policy deviates too much from the one that generated the data, producing a bad policy.
PPO lets us do multiple gradient updates per sample by keeping the
updated policy close to the policy that was used to sample the data.
It does so by clipping the objective, which cuts off the gradient when the updated policy
is not close to the policy used to sample the data.</p>
<p>You can find an experiment that uses it <a href="experiment.html">here</a>.
The experiment uses <a href="gae.html">Generalized Advantage Estimation</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">26</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">27</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">labml_nn.rl.ppo.gae</span> <span class="kn">import</span> <span class="n">GAE</span></pre></div>
</div>
</div>
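<div class='section'>
<div class='docs'>
<p>To make this &ldquo;sample once, update several times&rdquo; idea concrete, here is a minimal sketch of that pattern.
It is not the training code of this repository (see the linked experiment for that);
the <code>policy</code> module returning a <code>torch.distributions.Categorical</code>, the
<code>sample_batch</code> function, and the hyper-parameter values are assumptions made purely for illustration.</p>
</div>
<div class='code'>
<div class="highlight"><pre># A minimal sketch of PPO's sample-then-reuse pattern (illustration only).
# `policy` is assumed to map observations to a torch.distributions.Categorical,
# and `sample_batch` to return observations, actions and advantages collected
# with the current (soon to be "old") policy.
import torch


def ppo_update_sketch(policy, optimizer, sample_batch, epochs=4, clip=0.2):
    obs, actions, advantages = sample_batch()

    # Log-probabilities under the policy that generated the data; kept fixed.
    with torch.no_grad():
        sampled_log_pi = policy(obs).log_prob(actions)

    # Unlike a vanilla policy gradient, the same batch is reused for several
    # gradient steps; the clipped objective keeps pi_theta close to pi_theta_old.
    for _ in range(epochs):
        log_pi = policy(obs).log_prob(actions)
        ratio = torch.exp(log_pi - sampled_log_pi)
        clipped_ratio = ratio.clamp(1.0 - clip, 1.0 + clip)
        loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()</pre></div>
</div>
</div>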
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>PPO Loss</h2>
<p>Here&rsquo;s how the PPO update rule is derived.</p>
<p>We want to maximize the policy reward
<script type="math/tex; mode=display">\max_\theta J(\pi_\theta) =
 \mathop{\mathbb{E}}_{\tau \sim \pi_\theta}\Biggl[\sum_{t=0}^\infty \gamma^t r_t \Biggr]</script>
where $r$ is the reward, $\pi$ is the policy, $\tau$ is a trajectory sampled from the policy,
and $\gamma$ is the discount factor in $[0, 1]$.</p>
<p>
<script type="math/tex; mode=display">\begin{align}
\mathbb{E}_{\tau \sim \pi_\theta} \Biggl[
\sum_{t=0}^\infty \gamma^t A^{\pi_{OLD}}(s_t, a_t)
\Biggr] &=
\\
\mathbb{E}_{\tau \sim \pi_\theta} \Biggl[
\sum_{t=0}^\infty \gamma^t \Bigl(
Q^{\pi_{OLD}}(s_t, a_t) - V^{\pi_{OLD}}(s_t)
\Bigr)
\Biggr] &=
\\
\mathbb{E}_{\tau \sim \pi_\theta} \Biggl[
\sum_{t=0}^\infty \gamma^t \Bigl(
r_t + \gamma V^{\pi_{OLD}}(s_{t+1}) - V^{\pi_{OLD}}(s_t)
\Bigr)
\Biggr] &=
\\
\mathbb{E}_{\tau \sim \pi_\theta} \Biggl[
\sum_{t=0}^\infty \gamma^t \Bigl(
r_t
\Bigr)
\Biggr]
- \mathbb{E}_{\tau \sim \pi_\theta}
\Biggl[V^{\pi_{OLD}}(s_0)\Biggr] &=
J(\pi_\theta) - J(\pi_{\theta_{OLD}})
\end{align}</script>
</p>
<p>The value function terms telescope,
$\sum_{t=0}^\infty \gamma^t \Bigl(\gamma V^{\pi_{OLD}}(s_{t+1}) - V^{\pi_{OLD}}(s_t)\Bigr) = -V^{\pi_{OLD}}(s_0)$,
and the expectation of $V^{\pi_{OLD}}(s_0)$ over the initial state distribution is $J(\pi_{\theta_{OLD}})$.
So,
<script type="math/tex; mode=display">\max_\theta J(\pi_\theta) =
\max_\theta \mathbb{E}_{\tau \sim \pi_\theta} \Biggl[
\sum_{t=0}^\infty \gamma^t A^{\pi_{OLD}}(s_t, a_t)
\Biggr]</script>
</p>
<p>Define the discounted-future state distribution,
<script type="math/tex; mode=display">d^\pi(s) = (1 - \gamma) \sum_{t=0}^\infty \gamma^t P(s_t = s | \pi)</script>
</p>
<p>Then,
<script type="math/tex; mode=display">\begin{align}
J(\pi_\theta) - J(\pi_{\theta_{OLD}})
&= \mathbb{E}_{\tau \sim \pi_\theta} \Biggl[
\sum_{t=0}^\infty \gamma^t A^{\pi_{OLD}}(s_t, a_t)
\Biggr]
\\
&= \frac{1}{1 - \gamma}
\mathbb{E}_{s \sim d^{\pi_\theta}, a \sim \pi_\theta} \Bigl[
A^{\pi_{OLD}}(s, a)
\Bigr]
\end{align}</script>
</p>
<p>Using importance sampling to take $a$ from $\pi_{\theta_{OLD}}$ instead of $\pi_\theta$,</p>
<p>
<script type="math/tex; mode=display">\begin{align}
J(\pi_\theta) - J(\pi_{\theta_{OLD}})
&= \frac{1}{1 - \gamma}
\mathbb{E}_{s \sim d^{\pi_\theta}, a \sim \pi_\theta} \Bigl[
A^{\pi_{OLD}}(s, a)
\Bigr]
\\
&= \frac{1}{1 - \gamma}
\mathbb{E}_{s \sim d^{\pi_\theta}, a \sim \pi_{\theta_{OLD}}} \Biggl[
\frac{\pi_\theta(a|s)}{\pi_{\theta_{OLD}}(a|s)} A^{\pi_{OLD}}(s, a)
\Biggr]
\end{align}</script>
</p>
<p>Then we assume $d^{\pi_\theta}(s)$ and $d^{\pi_{\theta_{OLD}}}(s)$ are similar.
The error this assumption introduces to $J(\pi_\theta) - J(\pi_{\theta_{OLD}})$
is bounded by the KL divergence between
$\pi_\theta$ and $\pi_{\theta_{OLD}}$.
<a href="https://arxiv.org/abs/1705.10528">Constrained Policy Optimization</a>
gives a proof of this bound. I haven&rsquo;t read it.</p>
<p>
<script type="math/tex; mode=display">\begin{align}
J(\pi_\theta) - J(\pi_{\theta_{OLD}})
&= \frac{1}{1 - \gamma}
\mathop{\mathbb{E}}_{s \sim d^{\pi_\theta} \atop a \sim \pi_{\theta_{OLD}}} \Biggl[
\frac{\pi_\theta(a|s)}{\pi_{\theta_{OLD}}(a|s)} A^{\pi_{OLD}}(s, a)
\Biggr]
\\
&\approx \frac{1}{1 - \gamma}
\mathop{\mathbb{E}}_{\color{orange}{s \sim d^{\pi_{\theta_{OLD}}}}
\atop a \sim \pi_{\theta_{OLD}}} \Biggl[
\frac{\pi_\theta(a|s)}{\pi_{\theta_{OLD}}(a|s)} A^{\pi_{OLD}}(s, a)
\Biggr]
\\
&= \frac{1}{1 - \gamma} \mathcal{L}^{CPI}
\end{align}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span class="k">class</span> <span class="nc">ClippedPPOLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
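<div class='section'>
<div class='docs'>
<p>As a small numeric illustration of the $\mathcal{L}^{CPI}$ surrogate derived above,
the snippet below uses made-up log-probabilities and advantages (they do not come from any environment);
it only shows how the importance-weighted objective is estimated before any clipping is applied.</p>
</div>
<div class='code'>
<div class="highlight"><pre># A tiny, self-contained illustration of the CPI surrogate
#   L^CPI = E[ (pi_theta(a|s) / pi_theta_old(a|s)) * A^{pi_old}(s, a) ]
# using made-up log-probabilities and advantages.
import torch

log_pi = torch.tensor([-0.9, -1.2, -0.3])          # log pi_theta(a|s)
sampled_log_pi = torch.tensor([-1.0, -1.0, -0.5])  # log pi_theta_old(a|s)
advantage = torch.tensor([1.0, -0.5, 2.0])         # A^{pi_old}(s, a)

# Importance sampling ratio pi_theta / pi_theta_old, computed in log space
ratio = torch.exp(log_pi - sampled_log_pi)

# Monte Carlo estimate of the (un-clipped) surrogate objective
l_cpi = (ratio * advantage).mean()
print(l_cpi)</pre></div>
</div>
</div>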
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">134</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">137</span> <span class="n">advantage</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>The ratio $r_t(\theta) = \frac{\pi_\theta (a_t|s_t)}{\pi_{\theta_{OLD}} (a_t|s_t)}$;
<em>this is different from the reward</em> $r_t$.</p>
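<p>Since the code works with log-probabilities, the ratio is computed in log space,
<script type="math/tex; mode=display">r_t(\theta) =
\exp \Bigl( \log \pi_\theta(a_t|s_t) - \log \pi_{\theta_{OLD}}(a_t|s_t) \Bigr)</script>
which is equivalent to the definition above and numerically more stable than dividing the probabilities directly.</p>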
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">ratio</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_pi</span> <span class="o">-</span> <span class="n">sampled_log_pi</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<h3>Clipping the policy ratio</h3>
<p>
<script type="math/tex; mode=display">\begin{align}
\mathcal{L}^{CLIP}(\theta) =
\mathbb{E}_{a_t, s_t \sim \pi_{\theta{OLD}}} \biggl[
min \Bigl(r_t(\theta) \bar{A_t},
clip \bigl(
r_t(\theta), 1 - \epsilon, 1 + \epsilon
\bigr) \bar{A_t}
\Bigr)
\biggr]
\end{align}</script>
</p>
<p>The ratio is clipped so that it stays close to 1.
We take the minimum so that the gradient only pulls
$\pi_\theta$ towards $\pi_{\theta_{OLD}}$ when the ratio is
outside the range $[1 - \epsilon, 1 + \epsilon]$.
This keeps the KL divergence between $\pi_\theta$
and $\pi_{\theta_{OLD}}$ constrained.
Large deviations can cause a performance collapse,
where the policy performance drops and doesn&rsquo;t recover because
we end up sampling from a bad policy.</p>
<p>Using the normalized advantage
$\bar{A_t} = \frac{\hat{A_t} - \mu(\hat{A_t})}{\sigma(\hat{A_t})}$
introduces a bias to the policy gradient estimator,
but it reduces variance a lot.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">clipped_ratio</span> <span class="o">=</span> <span class="n">ratio</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">clip</span><span class="p">,</span>
<span class="lineno">170</span> <span class="nb">max</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">clip</span><span class="p">)</span>
<span class="lineno">171</span> <span class="n">policy_reward</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">,</span>
<span class="lineno">172</span> <span class="n">clipped_ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">)</span>
<span class="lineno">173</span>
<span class="lineno">174</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_fraction</span> <span class="o">=</span> <span class="p">(</span><span class="nb">abs</span><span class="p">((</span><span class="n">ratio</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">))</span> <span class="o">&gt;</span> <span class="n">clip</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
<span class="lineno">175</span>
<span class="lineno">176</span> <span class="k">return</span> <span class="o">-</span><span class="n">policy_reward</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
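<div class='section'>
<div class='docs'>
<p>A minimal sketch of the advantage normalization mentioned above: the batch of $\hat{A_t}$ values below is random,
whereas the experiment computes them with GAE, and the small epsilon is an assumption added only to avoid
division by zero.</p>
</div>
<div class='code'>
<div class="highlight"><pre># A minimal sketch of advantage normalization, bar{A_t} = (hat{A_t} - mean) / std.
# The advantages here are random placeholders; the experiment computes them with GAE.
import torch

advantages = torch.randn(2048)  # hat{A_t} for one sampled batch

# The 1e-8 term is an assumption to guard against a zero standard deviation.
normalized = (advantages - advantages.mean()) / (advantages.std() + 1e-8)</pre></div>
</div>
</div>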
<div class='section' id='section-6'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<h2>Clipped Value Function Loss</h2>
<p>Similarly, we clip the value function update.</p>
<p>
<script type="math/tex; mode=display">\begin{align}
V^{\pi_\theta}_{CLIP}(s_t)
&= clip\Bigl(V^{\pi_\theta}(s_t) - \hat{V_t}, -\epsilon, +\epsilon\Bigr)
\\
\mathcal{L}^{VF}(\theta)
&= \frac{1}{2} \mathbb{E} \biggl[
max\Bigl(\bigl(V^{\pi_\theta}(s_t) - R_t\bigr)^2,
\bigl(V^{\pi_\theta}_{CLIP}(s_t) - R_t\bigr)^2\Bigr)
\biggr]
\end{align}</script>
</p>
<p>Clipping makes sure the value function $V_\theta$ doesn&rsquo;t deviate
significantly from $V_{\theta_{OLD}}$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span><span class="k">class</span> <span class="nc">ClippedValueFunctionLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_return</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
<span class="lineno">201</span> <span class="n">clipped_value</span> <span class="o">=</span> <span class="n">sampled_value</span> <span class="o">+</span> <span class="p">(</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_value</span><span class="p">)</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="n">clip</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">clip</span><span class="p">)</span>
<span class="lineno">202</span> <span class="n">vf_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="p">(</span><span class="n">clipped_value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="lineno">203</span> <span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">vf_loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
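<div class='section'>
<div class='docs'>
<p>As a quick usage sketch of the two loss modules defined in this file, the snippet below calls them on random tensors.
The batch size, the clip range, and the 0.5 weighting of the value loss are arbitrary choices made for illustration,
not the settings used in the linked experiment.</p>
</div>
<div class='code'>
<div class="highlight"><pre># A usage sketch of the two losses in this file, called on random tensors.
# Shapes, the clip range and the loss weighting are arbitrary illustrations.
import torch

from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss

batch = 1024
log_pi = torch.randn(batch)          # log pi_theta(a|s) from the current policy
sampled_log_pi = torch.randn(batch)  # log pi_theta_old(a|s) stored at sampling time
advantage = torch.randn(batch)       # normalized advantage estimates
value = torch.randn(batch)           # V(s) from the current value head
sampled_value = torch.randn(batch)   # V(s) stored at sampling time
sampled_return = torch.randn(batch)  # sampled returns R_t

policy_loss = ClippedPPOLoss()(log_pi, sampled_log_pi, advantage, clip=0.2)
value_loss = ClippedValueFunctionLoss()(value, sampled_value, sampled_return, clip=0.2)

# One possible way to combine the two terms; the weighting here is arbitrary.
loss = policy_loss + 0.5 * value_loss</pre></div>
</div>
</div>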
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
</body>
</html>