📚 ppo intro

Varuna Jayasiri
2021-02-23 17:46:27 +05:30
parent 5442dfb130
commit c1e9b0ce32
6 changed files with 370 additions and 41 deletions

View File

@ -75,14 +75,22 @@
<h1>Proximal Policy Optimization (PPO)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of
<a href="https://arxiv.org/abs/1707.06347">Proximal Policy Optimization - PPO</a>.</p>
<p>PPO is a policy gradient method for reinforcement learning.
Simple policy gradient methods do a single gradient update per sample (or a set of samples).
Doing multiple gradient steps on a single sample causes problems
because the policy deviates too much, producing a bad policy.
PPO lets us do multiple gradient updates per sample by trying to keep the
policy close to the policy that was used to sample the data.
It does so by clipping the gradient flow if the updated policy
is not close to the policy used to sample the data.</p>
<p>You can find an experiment that uses it <a href="experiment.html">here</a>.
The experiment uses <a href="gae.html">Generalized Advantage Estimation</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">17</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">18</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.rl.ppo.gae</span> <span class="kn">import</span> <span class="n">GAE</span></pre></div>
<div class="highlight"><pre><span class="lineno">26</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">27</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">labml_nn.rl.ppo.gae</span> <span class="kn">import</span> <span class="n">GAE</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -91,6 +99,7 @@ The experiment uses <a href="gae.html">Generalized Advantage Estimation</a>.</p>
<a href='#section-1'>#</a>
</div>
<h2>PPO Loss</h2>
<p>Here&rsquo;s how the PPO update rule is derived.</p>
<p>We want to maximize policy reward
<script type="math/tex; mode=display">\max_\theta J(\pi_\theta) =
\mathop{\mathbb{E}}_{\tau \sim \pi_\theta}\Biggl[\sum_{t=0}^\infty \gamma^t r_t \Biggr]</script>
@ -186,7 +195,7 @@ J(\pi_\theta) - J(\pi_{\theta_{OLD}})
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">23</span><span class="k">class</span> <span class="nc">ClippedPPOLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">32</span><span class="k">class</span> <span class="nc">ClippedPPOLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -197,8 +206,8 @@ J(\pi_\theta) - J(\pi_{\theta_{OLD}})
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">123</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">134</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -209,8 +218,8 @@ J(\pi_\theta) - J(\pi_{\theta_{OLD}})
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">125</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">126</span> <span class="n">advantage</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">136</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">137</span> <span class="n">advantage</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -222,7 +231,7 @@ J(\pi_\theta) - J(\pi_{\theta_{OLD}})
<em>this is different from rewards</em> $r_t$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">ratio</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_pi</span> <span class="o">-</span> <span class="n">sampled_log_pi</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">ratio</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_pi</span> <span class="o">-</span> <span class="n">sampled_log_pi</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -230,7 +239,8 @@ J(\pi_\theta) - J(\pi_{\theta_{OLD}})
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>
<h3>Clipping the policy ratio</h3>
<p>
<script type="math/tex; mode=display">\begin{align}
\mathcal{L}^{CLIP}(\theta) =
\mathbb{E}_{a_t, s_t \sim \pi_{\theta{OLD}}} \biggl[
@ -257,14 +267,14 @@ Large deviation can cause performance collapse;
but it reduces variance a lot.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">clipped_ratio</span> <span class="o">=</span> <span class="n">ratio</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">clip</span><span class="p">,</span>
<span class="lineno">157</span> <span class="nb">max</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">clip</span><span class="p">)</span>
<span class="lineno">158</span> <span class="n">policy_reward</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">,</span>
<span class="lineno">159</span> <span class="n">clipped_ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">)</span>
<span class="lineno">160</span>
<span class="lineno">161</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_fraction</span> <span class="o">=</span> <span class="p">(</span><span class="nb">abs</span><span class="p">((</span><span class="n">ratio</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">))</span> <span class="o">&gt;</span> <span class="n">clip</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
<span class="lineno">162</span>
<span class="lineno">163</span> <span class="k">return</span> <span class="o">-</span><span class="n">policy_reward</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">clipped_ratio</span> <span class="o">=</span> <span class="n">ratio</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">clip</span><span class="p">,</span>
<span class="lineno">170</span> <span class="nb">max</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">clip</span><span class="p">)</span>
<span class="lineno">171</span> <span class="n">policy_reward</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">,</span>
<span class="lineno">172</span> <span class="n">clipped_ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">)</span>
<span class="lineno">173</span>
<span class="lineno">174</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_fraction</span> <span class="o">=</span> <span class="p">(</span><span class="nb">abs</span><span class="p">((</span><span class="n">ratio</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">))</span> <span class="o">&gt;</span> <span class="n">clip</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
<span class="lineno">175</span>
<span class="lineno">176</span> <span class="k">return</span> <span class="o">-</span><span class="n">policy_reward</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
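<div class='section'>
<div class='docs'>
<p>A quick usage sketch with made-up tensors: with <code>clip = 0.2</code>, any probability ratio outside
$[0.8, 1.2]$ is clamped, so the objective gives no incentive to move the policy further than that from the
policy that sampled the data.</p>
</div>
<div class='code'>
<pre>import torch

from labml_nn.rl.ppo import ClippedPPOLoss

ppo_loss = ClippedPPOLoss()

# Log-probabilities under the current policy and under the policy that sampled the data
log_pi = torch.tensor([-0.9, -1.2, -2.5])
sampled_log_pi = torch.tensor([-1.0, -1.0, -2.0])
# Normalized advantage estimates (e.g. from GAE)
advantage = torch.tensor([1.0, -0.5, 2.0])

loss = ppo_loss(log_pi, sampled_log_pi, advantage, clip=0.2)
print(loss, ppo_loss.clip_fraction)</pre>
</div>
</div>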
<div class='section' id='section-6'>
@ -273,6 +283,7 @@ Large deviation can cause performance collapse;
<a href='#section-6'>#</a>
</div>
<h2>Clipped Value Function Loss</h2>
<p>Similarly, we clip the value function update.</p>
<p>
<script type="math/tex; mode=display">\begin{align}
V^{\pi_\theta}_{CLIP}(s_t)
@ -289,7 +300,7 @@ V^{\pi_\theta}_{CLIP}(s_t)
significantly from $V_{\theta_{OLD}}$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span><span class="k">class</span> <span class="nc">ClippedValueFunctionLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">179</span><span class="k">class</span> <span class="nc">ClippedValueFunctionLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -300,10 +311,10 @@ V^{\pi_\theta}_{CLIP}(s_t)
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">185</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_return</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
<span class="lineno">186</span> <span class="n">clipped_value</span> <span class="o">=</span> <span class="n">sampled_value</span> <span class="o">+</span> <span class="p">(</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_value</span><span class="p">)</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="n">clip</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">clip</span><span class="p">)</span>
<span class="lineno">187</span> <span class="n">vf_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="p">(</span><span class="n">clipped_value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="lineno">188</span> <span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">vf_loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">200</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_return</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
<span class="lineno">201</span> <span class="n">clipped_value</span> <span class="o">=</span> <span class="n">sampled_value</span> <span class="o">+</span> <span class="p">(</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_value</span><span class="p">)</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="n">clip</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">clip</span><span class="p">)</span>
<span class="lineno">202</span> <span class="n">vf_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="p">(</span><span class="n">clipped_value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="lineno">203</span> <span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">vf_loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
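<div class='section'>
<div class='docs'>
<p>A similar sketch for the value loss, again with made-up tensors: taking the maximum of the clipped and
unclipped squared errors means the value estimate gains nothing by moving more than <code>clip</code> away
from the value recorded when the data was sampled.</p>
</div>
<div class='code'>
<pre>import torch

from labml_nn.rl.ppo import ClippedValueFunctionLoss

value_loss = ClippedValueFunctionLoss()

value = torch.tensor([1.50, 0.20])           # current value predictions
sampled_value = torch.tensor([1.00, 0.25])   # predictions made at sampling time
sampled_return = torch.tensor([1.10, 0.40])  # sampled returns

print(value_loss(value, sampled_value, sampled_return, clip=0.1))</pre>
</div>
</div>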
</div>

docs/rl/ppo/readme.html Normal file

@ -0,0 +1,126 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Proximal Policy Optimization (PPO)"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/rl/ppo/readme.html"/>
<meta property="og:title" content="Proximal Policy Optimization (PPO)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Proximal Policy Optimization (PPO)"/>
<meta property="og:description" content=""/>
<title>Proximal Policy Optimization (PPO)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/rl/ppo/readme.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">rl</a>
<a class="parent" href="index.html">ppo</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/rl/ppo/readme.md">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
rel="nofollow">
<img alt="Join Slact"
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/rl/ppo/index.html">Proximal Policy Optimization (PPO)</a></h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of
<a href="https://arxiv.org/abs/1707.06347">Proximal Policy Optimization - PPO</a>.</p>
<p>PPO is a policy gradient method for reinforcement learning.
Simple policy gradient methods do a single gradient update per sample (or a set of samples).
Doing multiple gradient steps on a single sample causes problems
because the policy deviates too much, producing a bad policy.
PPO lets us do multiple gradient updates per sample by trying to keep the
policy close to the policy that was used to sample the data.
It does so by clipping the gradient flow if the updated policy
is not close to the policy used to sample the data.</p>
<p>You can find an experiment that uses it <a href="https://nn.labml.ai/rl/ppo/experiment.html">here</a>.
The experiment uses <a href="https://nn.labml.ai/rl/ppo/gae.html">Generalized Advantage Estimation</a>.</p>
</div>
<div class='code'>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
</body>
</html>

View File

@ -120,7 +120,7 @@
<url>
<loc>https://nn.labml.ai/normalization/batch_norm/readme.html</loc>
<lastmod>2021-02-02T16:30:00+00:00</lastmod>
<lastmod>2021-02-19T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@ -155,7 +155,7 @@
<url>
<loc>https://nn.labml.ai/index.html</loc>
<lastmod>2021-02-07T16:30:00+00:00</lastmod>
<lastmod>2021-02-19T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@ -302,7 +302,7 @@
<url>
<loc>https://nn.labml.ai/transformers/index.html</loc>
<lastmod>2021-02-07T16:30:00+00:00</lastmod>
<lastmod>2021-02-19T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@ -379,14 +379,14 @@
<url>
<loc>https://nn.labml.ai/transformers/switch/index.html</loc>
<lastmod>2021-02-10T16:30:00+00:00</lastmod>
<lastmod>2021-02-19T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/transformers/switch/readme.html</loc>
<lastmod>2021-02-01T16:30:00+00:00</lastmod>
<lastmod>2021-02-19T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@ -414,21 +414,35 @@
<url>
<loc>https://nn.labml.ai/transformers/mha.html</loc>
<lastmod>2021-02-17T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/transformers/compressive/index.html</loc>
<lastmod>2021-02-17T16:30:00+00:00</lastmod>
<lastmod>2021-02-19T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/transformers/compressive/experiment.html</loc>
<lastmod>2021-02-17T16:30:00+00:00</lastmod>
<lastmod>2021-02-19T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/transformers/compressive/index.html</loc>
<lastmod>2021-02-19T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/transformers/compressive/readme.html</loc>
<lastmod>2021-02-19T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/transformers/compressive/experiment.html</loc>
<lastmod>2021-02-19T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@ -442,21 +456,21 @@
<url>
<loc>https://nn.labml.ai/transformers/xl/experiment.html</loc>
<lastmod>2021-02-07T16:30:00+00:00</lastmod>
<lastmod>2021-02-19T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/transformers/xl/index.html</loc>
<lastmod>2021-02-10T16:30:00+00:00</lastmod>
<lastmod>2021-02-18T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/transformers/xl/readme.html</loc>
<lastmod>2021-02-07T16:30:00+00:00</lastmod>
<lastmod>2021-02-19T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>

View File

@ -0,0 +1,147 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Compressive Transformer"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/compressive/readme.html"/>
<meta property="og:title" content="Compressive Transformer"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Compressive Transformer"/>
<meta property="og:description" content=""/>
<title>Compressive Transformer</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/transformers/compressive/readme.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">transformers</a>
<a class="parent" href="index.html">compressive</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/transformers/compressive/readme.md">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
rel="nofollow">
<img alt="Join Slact"
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/transformers/compressive/index.html">Compressive Transformer</a></h1>
<p>This is an implementation of
<a href="https://arxiv.org/abs/1911.05507">Compressive Transformers for Long-Range Sequence Modelling</a>
in <a href="https://pytorch.org">PyTorch</a>.</p>
<p>This is an extension of <a href="https://nn.labml.ai/transformers/xl/index.html">Transformer XL</a> where past memories
are compressed to give a longer attention range.
That is, the furthest $n_{cm} c$ memories are compressed into
$n_{cm}$ memories, where $c$ is the compression rate.</p>
<h2>Compression operation</h2>
<p>The compression operation is defined as
$f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$.
The paper introduces multiple choices for $f_c$; we have implemented only
1D convolution, which seems to give the best results.
Each layer has a separate compression operation $f_c^{(i)}$ where
$i$ is the layer number.</p>
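<p>A rough sketch of such a compression function, assuming memories of shape
<code>[seq_len, batch, d_model]</code> (the actual module in the implementation may differ in details):</p>
<pre>import torch
from torch import nn


class Conv1dCompression(nn.Module):
    """Compress `compression_rate * n` memories down to `n` memories with a strided 1D convolution."""
    def __init__(self, compression_rate: int, d_model: int):
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model,
                              kernel_size=compression_rate, stride=compression_rate)

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        # `[seq_len, batch, d_model]` -> `[batch, d_model, seq_len]`
        mem = mem.permute(1, 2, 0)
        # Convolve along the sequence dimension: `seq_len` shrinks by the compression rate
        c_mem = self.conv(mem)
        # Back to `[seq_len / compression_rate, batch, d_model]`
        return c_mem.permute(2, 0, 1)</pre>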
<h2>Training compression operation</h2>
<p>Since training compression with BPTT requires maintaining
a very large computational graph (many time steps), the paper proposes
an <em>auto-encoding loss</em> and an <em>attention reconstruction loss</em>.
The auto-encoding loss decodes the original memories from the compressed memories
and calculates the loss.
The attention reconstruction loss computes the multi-headed attention results
on the compressed memory and on the uncompressed memory, and takes the mean squared error
between them.
We have implemented the latter here since it gives better results.</p>
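<p>A simplified sketch of the attention reconstruction loss, assuming a generic <code>attention</code> module
that takes query, key and value tensors (in the full method, gradients from this loss are also kept from
flowing into the main network, so that only the compression function is trained by it):</p>
<pre>import torch
import torch.nn.functional as F


def attention_reconstruction_loss(attention, h, mem, c_mem):
    # Attention over the original (uncompressed) memories is the reconstruction target
    with torch.no_grad():
        target = attention(query=h, key=mem, value=mem)
    # Attention over the compressed memories produced by the compression function
    reconstruction = attention(query=h, key=c_mem, value=c_mem)
    # Mean squared error between the two attention results
    return F.mse_loss(reconstruction, target)</pre>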
<p>This implementation uses pre-layer normalization
while the paper uses post-layer normalization.
Pre-layer norm applies layer normalization before the <a href="../feedforward.html">FFN</a> and
self-attention, and the pass-through in the residual connection is not normalized.
This is supposed to be more stable in standard transformer setups.</p>
<p>Here are <a href="https://nn.labml.ai/transformers/compressive/experiment.html">the training code</a> and a notebook for training a compressive transformer
model on the Tiny Shakespeare dataset.</p>
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/compressive/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://web.lab-ml.com/run?uuid=0d9b5338726c11ebb7c80242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
</body>
</html>

View File

@ -10,6 +10,15 @@ summary: >
This is a [PyTorch](https://pytorch.org) implementation of
[Proximal Policy Optimization - PPO](https://arxiv.org/abs/1707.06347).
PPO is a policy gradient method for reinforcement learning.
Simple policy gradient methods do a single gradient update per sample (or a set of samples).
Doing multiple gradient steps on a single sample causes problems
because the policy deviates too much, producing a bad policy.
PPO lets us do multiple gradient updates per sample by trying to keep the
policy close to the policy that was used to sample the data.
It does so by clipping the gradient flow if the updated policy
is not close to the policy used to sample the data.
You can find an experiment that uses it [here](experiment.html).
The experiment uses [Generalized Advantage Estimation](gae.html).
"""
@ -24,6 +33,8 @@ class ClippedPPOLoss(Module):
"""
## PPO Loss
Here's how the PPO update rule is derived.
We want to maximize policy reward
$$\max_\theta J(\pi_\theta) =
\mathop{\mathbb{E}}_{\tau \sim \pi_\theta}\Biggl[\sum_{t=0}^\infty \gamma^t r_t \Biggr]$$
@ -128,6 +139,8 @@ class ClippedPPOLoss(Module):
# *this is different from rewards* $r_t$.
ratio = torch.exp(log_pi - sampled_log_pi)
# ### Clipping the policy ratio
#
# \begin{align}
# \mathcal{L}^{CLIP}(\theta) =
# \mathbb{E}_{a_t, s_t \sim \pi_{\theta{OLD}}} \biggl[
@ -167,6 +180,8 @@ class ClippedValueFunctionLoss(Module):
"""
## Clipped Value Function Loss
Similarly, we clip the value function update.
\begin{align}
V^{\pi_\theta}_{CLIP}(s_t)
&= clip\Bigl(V^{\pi_\theta}(s_t) - \hat{V_t}, -\epsilon, +\epsilon\Bigr)

labml_nn/rl/ppo/readme.md Normal file

@ -0,0 +1,16 @@
# [Proximal Policy Optimization (PPO)](https://nn.labml.ai/rl/ppo/index.html)
This is a [PyTorch](https://pytorch.org) implementation of
[Proximal Policy Optimization - PPO](https://arxiv.org/abs/1707.06347).
PPO is a policy gradient method for reinforcement learning.
Simple policy gradient methods do a single gradient update per sample (or a set of samples).
Doing multiple gradient steps on a single sample causes problems
because the policy deviates too much, producing a bad policy.
PPO lets us do multiple gradient updates per sample by trying to keep the
policy close to the policy that was used to sample the data.
It does so by clipping the gradient flow if the updated policy
is not close to the policy used to sample the data.
You can find an experiment that uses it [here](https://nn.labml.ai/rl/ppo/experiment.html).
The experiment uses [Generalized Advantage Estimation](https://nn.labml.ai/rl/ppo/gae.html).
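
The heart of the method is the clipped surrogate objective. Here is a stand-alone sketch of that computation
with made-up tensors (the full version, including the clipped value function loss, is in the
[annotated implementation](https://nn.labml.ai/rl/ppo/index.html)):

```python
import torch

# Made-up mini-batch, for illustration only
log_pi = torch.tensor([-0.9, -1.2, -2.5])          # log-probabilities under the current policy
sampled_log_pi = torch.tensor([-1.0, -1.0, -2.0])  # log-probabilities under the sampling policy
advantage = torch.tensor([1.0, -0.5, 2.0])         # advantage estimates (e.g. from GAE)
clip = 0.2

# Probability ratio r_t = pi_theta / pi_theta_old, computed from log-probabilities
ratio = torch.exp(log_pi - sampled_log_pi)
# Clamp the ratio to [1 - clip, 1 + clip]
clipped_ratio = ratio.clamp(min=1.0 - clip, max=1.0 + clip)
# Take the pessimistic (minimum) of the clipped and unclipped objectives
policy_reward = torch.min(ratio * advantage, clipped_ratio * advantage)
# Minimize the negative of the reward
loss = -policy_reward.mean()
```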