ppo colab

This commit is contained in:
Varuna Jayasiri
2021-03-30 12:33:46 +05:30
parent ac40d0a7c9
commit 50bd0556a5
10 changed files with 310 additions and 265 deletions

File diff suppressed because it is too large Load Diff

View File

@ -75,9 +75,10 @@
<h1>Generalized Advantage Estimation (GAE)</h1> <h1>Generalized Advantage Estimation (GAE)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of paper <p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of paper
<a href="https://arxiv.org/abs/1506.02438">Generalized Advantage Estimation</a>.</p> <a href="https://arxiv.org/abs/1506.02438">Generalized Advantage Estimation</a>.</p>
<p>You can find an experiment that uses it <a href="experiment.html">here</a>.</p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">13</span><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span></pre></div> <div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-1'> <div class='section' id='section-1'>
@ -88,7 +89,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">16</span><span class="k">class</span> <span class="nc">GAE</span><span class="p">:</span></pre></div> <div class="highlight"><pre><span class="lineno">18</span><span class="k">class</span> <span class="nc">GAE</span><span class="p">:</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-2'> <div class='section' id='section-2'>
@ -99,11 +100,11 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">17</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_workers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">worker_steps</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">gamma</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">lambda_</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">19</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_workers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">worker_steps</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">gamma</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">lambda_</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
<span class="lineno">18</span> <span class="bp">self</span><span class="o">.</span><span class="n">lambda_</span> <span class="o">=</span> <span class="n">lambda_</span> <span class="lineno">20</span> <span class="bp">self</span><span class="o">.</span><span class="n">lambda_</span> <span class="o">=</span> <span class="n">lambda_</span>
<span class="lineno">19</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span> <span class="o">=</span> <span class="n">gamma</span> <span class="lineno">21</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span> <span class="o">=</span> <span class="n">gamma</span>
<span class="lineno">20</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span> <span class="o">=</span> <span class="n">worker_steps</span> <span class="lineno">22</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span> <span class="o">=</span> <span class="n">worker_steps</span>
<span class="lineno">21</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span> <span class="o">=</span> <span class="n">n_workers</span></pre></div> <span class="lineno">23</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span> <span class="o">=</span> <span class="n">n_workers</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-3'> <div class='section' id='section-3'>
@ -142,7 +143,7 @@ $\hat{A_t}$</p>
</p> </p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">23</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">done</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">rewards</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">values</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span></pre></div> <div class="highlight"><pre><span class="lineno">25</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">done</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">rewards</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">values</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-4'> <div class='section' id='section-4'>
@ -153,8 +154,8 @@ $\hat{A_t}$</p>
<p>advantages table</p> <p>advantages table</p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">advantages</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <div class="highlight"><pre><span class="lineno">58</span> <span class="n">advantages</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="lineno">57</span> <span class="n">last_advantage</span> <span class="o">=</span> <span class="mi">0</span></pre></div> <span class="lineno">59</span> <span class="n">last_advantage</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-5'> <div class='section' id='section-5'>
@ -165,9 +166,9 @@ $\hat{A_t}$</p>
<p>$V(s_{t+1})$</p> <p>$V(s_{t+1})$</p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">last_value</span> <span class="o">=</span> <span class="n">values</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span> <div class="highlight"><pre><span class="lineno">62</span> <span class="n">last_value</span> <span class="o">=</span> <span class="n">values</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="lineno">61</span> <span class="lineno">63</span>
<span class="lineno">62</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">)):</span></pre></div> <span class="lineno">64</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">)):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-6'> <div class='section' id='section-6'>
@ -178,9 +179,9 @@ $\hat{A_t}$</p>
<p>mask if episode completed after step $t$</p> <p>mask if episode completed after step $t$</p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="n">mask</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">done</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <div class="highlight"><pre><span class="lineno">66</span> <span class="n">mask</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">done</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span>
<span class="lineno">65</span> <span class="n">last_value</span> <span class="o">=</span> <span class="n">last_value</span> <span class="o">*</span> <span class="n">mask</span> <span class="lineno">67</span> <span class="n">last_value</span> <span class="o">=</span> <span class="n">last_value</span> <span class="o">*</span> <span class="n">mask</span>
<span class="lineno">66</span> <span class="n">last_advantage</span> <span class="o">=</span> <span class="n">last_advantage</span> <span class="o">*</span> <span class="n">mask</span></pre></div> <span class="lineno">68</span> <span class="n">last_advantage</span> <span class="o">=</span> <span class="n">last_advantage</span> <span class="o">*</span> <span class="n">mask</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-7'> <div class='section' id='section-7'>
@ -191,7 +192,7 @@ $\hat{A_t}$</p>
<p>$\delta_t$</p> <p>$\delta_t$</p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">delta</span> <span class="o">=</span> <span class="n">rewards</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span> <span class="o">*</span> <span class="n">last_value</span> <span class="o">-</span> <span class="n">values</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span></pre></div> <div class="highlight"><pre><span class="lineno">70</span> <span class="n">delta</span> <span class="o">=</span> <span class="n">rewards</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span> <span class="o">*</span> <span class="n">last_value</span> <span class="o">-</span> <span class="n">values</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-8'> <div class='section' id='section-8'>
@ -202,7 +203,7 @@ $\hat{A_t}$</p>
<p>$\hat{A_t} = \delta_t + \gamma \lambda \hat{A_{t+1}}$</p> <p>$\hat{A_t} = \delta_t + \gamma \lambda \hat{A_{t+1}}$</p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="n">last_advantage</span> <span class="o">=</span> <span class="n">delta</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">lambda_</span> <span class="o">*</span> <span class="n">last_advantage</span></pre></div> <div class="highlight"><pre><span class="lineno">73</span> <span class="n">last_advantage</span> <span class="o">=</span> <span class="n">delta</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">lambda_</span> <span class="o">*</span> <span class="n">last_advantage</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-9'> <div class='section' id='section-9'>
@ -219,11 +220,11 @@ The performance of the model was improving
probably because the samples are similar.</em></p> probably because the samples are similar.</em></p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="n">advantages</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">last_advantage</span> <div class="highlight"><pre><span class="lineno">82</span> <span class="n">advantages</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">last_advantage</span>
<span class="lineno">81</span>
<span class="lineno">82</span> <span class="n">last_value</span> <span class="o">=</span> <span class="n">values</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span>
<span class="lineno">83</span> <span class="lineno">83</span>
<span class="lineno">84</span> <span class="k">return</span> <span class="n">advantages</span></pre></div> <span class="lineno">84</span> <span class="n">last_value</span> <span class="o">=</span> <span class="n">values</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span>
<span class="lineno">85</span>
<span class="lineno">86</span> <span class="k">return</span> <span class="n">advantages</span></pre></div>
</div> </div>
</div> </div>
</div> </div>

View File

@ -85,12 +85,14 @@ It does so by clipping gradient flow if the updated policy
is not close to the policy used to sample the data.</p> is not close to the policy used to sample the data.</p>
<p>You can find an experiment that uses it <a href="experiment.html">here</a>. <p>You can find an experiment that uses it <a href="experiment.html">here</a>.
The experiment uses <a href="gae.html">Generalized Advantage Estimation</a>.</p> The experiment uses <a href="gae.html">Generalized Advantage Estimation</a>.</p>
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/rl/ppo/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">26</span><span></span><span class="kn">import</span> <span class="nn">torch</span> <div class="highlight"><pre><span class="lineno">29</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">27</span> <span class="lineno">30</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span> <span class="lineno">31</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">labml_nn.rl.ppo.gae</span> <span class="kn">import</span> <span class="n">GAE</span></pre></div> <span class="lineno">32</span><span class="kn">from</span> <span class="nn">labml_nn.rl.ppo.gae</span> <span class="kn">import</span> <span class="n">GAE</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-1'> <div class='section' id='section-1'>
@ -195,7 +197,7 @@ J(\pi_\theta) - J(\pi_{\theta_{OLD}})
</p> </p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span class="k">class</span> <span class="nc">ClippedPPOLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">35</span><span class="k">class</span> <span class="nc">ClippedPPOLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-2'> <div class='section' id='section-2'>
@ -206,8 +208,8 @@ J(\pi_\theta) - J(\pi_{\theta_{OLD}})
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">136</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">134</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div> <span class="lineno">137</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-3'> <div class='section' id='section-3'>
@ -218,8 +220,8 @@ J(\pi_\theta) - J(\pi_{\theta_{OLD}})
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <div class="highlight"><pre><span class="lineno">139</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">137</span> <span class="n">advantage</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div> <span class="lineno">140</span> <span class="n">advantage</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-4'> <div class='section' id='section-4'>
@ -231,7 +233,7 @@ J(\pi_\theta) - J(\pi_{\theta_{OLD}})
<em>this is different from rewards</em> $r_t$.</p> <em>this is different from rewards</em> $r_t$.</p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">ratio</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_pi</span> <span class="o">-</span> <span class="n">sampled_log_pi</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">143</span> <span class="n">ratio</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_pi</span> <span class="o">-</span> <span class="n">sampled_log_pi</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-5'> <div class='section' id='section-5'>
@ -267,14 +269,14 @@ Large deviation can cause performance collapse;
but it reduces variance a lot.</p> but it reduces variance a lot.</p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">clipped_ratio</span> <span class="o">=</span> <span class="n">ratio</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">clip</span><span class="p">,</span> <div class="highlight"><pre><span class="lineno">172</span> <span class="n">clipped_ratio</span> <span class="o">=</span> <span class="n">ratio</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">clip</span><span class="p">,</span>
<span class="lineno">170</span> <span class="nb">max</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">clip</span><span class="p">)</span> <span class="lineno">173</span> <span class="nb">max</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">clip</span><span class="p">)</span>
<span class="lineno">171</span> <span class="n">policy_reward</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">,</span> <span class="lineno">174</span> <span class="n">policy_reward</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">,</span>
<span class="lineno">172</span> <span class="n">clipped_ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">)</span> <span class="lineno">175</span> <span class="n">clipped_ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">)</span>
<span class="lineno">173</span> <span class="lineno">176</span>
<span class="lineno">174</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_fraction</span> <span class="o">=</span> <span class="p">(</span><span class="nb">abs</span><span class="p">((</span><span class="n">ratio</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">))</span> <span class="o">&gt;</span> <span class="n">clip</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span> <span class="lineno">177</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_fraction</span> <span class="o">=</span> <span class="p">(</span><span class="nb">abs</span><span class="p">((</span><span class="n">ratio</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">))</span> <span class="o">&gt;</span> <span class="n">clip</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
<span class="lineno">175</span> <span class="lineno">178</span>
<span class="lineno">176</span> <span class="k">return</span> <span class="o">-</span><span class="n">policy_reward</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div> <span class="lineno">179</span> <span class="k">return</span> <span class="o">-</span><span class="n">policy_reward</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-6'> <div class='section' id='section-6'>
@ -300,7 +302,7 @@ V^{\pi_\theta}_{CLIP}(s_t)
significantly from $V_{\theta_{OLD}}$.</p> significantly from $V_{\theta_{OLD}}$.</p>
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">179</span><span class="k">class</span> <span class="nc">ClippedValueFunctionLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">182</span><span class="k">class</span> <span class="nc">ClippedValueFunctionLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-7'> <div class='section' id='section-7'>
@ -311,10 +313,10 @@ V^{\pi_\theta}_{CLIP}(s_t)
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_return</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">203</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_return</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
<span class="lineno">201</span> <span class="n">clipped_value</span> <span class="o">=</span> <span class="n">sampled_value</span> <span class="o">+</span> <span class="p">(</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_value</span><span class="p">)</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="n">clip</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">clip</span><span class="p">)</span> <span class="lineno">204</span> <span class="n">clipped_value</span> <span class="o">=</span> <span class="n">sampled_value</span> <span class="o">+</span> <span class="p">(</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_value</span><span class="p">)</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="n">clip</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">clip</span><span class="p">)</span>
<span class="lineno">202</span> <span class="n">vf_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="p">(</span><span class="n">clipped_value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span> <span class="lineno">205</span> <span class="n">vf_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="p">(</span><span class="n">clipped_value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="lineno">203</span> <span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">vf_loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div> <span class="lineno">206</span> <span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">vf_loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div> </div>
</div> </div>
</div> </div>

View File

@ -85,6 +85,8 @@ It does so by clipping gradient flow if the updated policy
is not close to the policy used to sample the data.</p> is not close to the policy used to sample the data.</p>
<p>You can find an experiment that uses it <a href="https://nn.labml.ai/rl/ppo/experiment.html">here</a>. <p>You can find an experiment that uses it <a href="https://nn.labml.ai/rl/ppo/experiment.html">here</a>.
The experiment uses <a href="https://nn.labml.ai/rl/ppo/gae.html">Generalized Advantage Estimation</a>.</p> The experiment uses <a href="https://nn.labml.ai/rl/ppo/gae.html">Generalized Advantage Estimation</a>.</p>
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/rl/ppo/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div> </div>
<div class='code'> <div class='code'>

View File

@ -699,6 +699,13 @@
</url> </url>
<url>
<loc>https://nn.labml.ai/rl/ppo/experiment.html</loc>
<lastmod>2021-03-30T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url> <url>
<loc>https://nn.labml.ai/rl/ppo/index.html</loc> <loc>https://nn.labml.ai/rl/ppo/index.html</loc>
<lastmod>2021-03-27T16:30:00+00:00</lastmod> <lastmod>2021-03-27T16:30:00+00:00</lastmod>
@ -722,7 +729,7 @@
<url> <url>
<loc>https://nn.labml.ai/rl/ppo/experiment.html</loc> <loc>https://nn.labml.ai/rl/ppo/experiment.html</loc>
<lastmod>2021-03-27T16:30:00+00:00</lastmod> <lastmod>2021-03-30T16:30:00+00:00</lastmod>
<priority>1.00</priority> <priority>1.00</priority>
</url> </url>

View File

@ -21,6 +21,9 @@ is not close to the policy used to sample the data.
You can find an experiment that uses it [here](experiment.html). You can find an experiment that uses it [here](experiment.html).
The experiment uses [Generalized Advantage Estimation](gae.html). The experiment uses [Generalized Advantage Estimation](gae.html).
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/rl/ppo/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f)
""" """
import torch import torch

View File

@ -86,7 +86,10 @@
"id": "-OnHLi626tJt" "id": "-OnHLi626tJt"
}, },
"source": [ "source": [
"Configurations" "### Configurations\n",
"\n",
"`IntDynamicHyperParam` and `FloatDynamicHyperParam` are dynamic hyper parameters\n",
"that you can change while the experiment is running."
] ]
}, },
{ {

View File

@ -8,6 +8,9 @@ summary: Annotated implementation to train a PPO agent on Atari Breakout game.
This experiment trains a Proximal Policy Optimization (PPO) agent on the Atari Breakout game from OpenAI Gym. This experiment trains a Proximal Policy Optimization (PPO) agent on the Atari Breakout game from OpenAI Gym.
It runs the [game environments on multiple processes](../game.html) to sample efficiently. It runs the [game environments on multiple processes](../game.html) to sample efficiently.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/rl/ppo/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f)
""" """
from typing import Dict from typing import Dict
@ -354,23 +357,31 @@ def main():
experiment.create(name='ppo') experiment.create(name='ppo')
# Configurations # Configurations
configs = { configs = {
# number of updates # Number of updates
'updates': 10000, 'updates': 10000,
# number of epochs to train the model with sampled data # ⚙️ Number of epochs to train the model with sampled data.
# You can change this while the experiment is running.
# [![Example](https://img.shields.io/badge/example-hyperparams-brightgreen)](https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f/hyper_params)
'epochs': IntDynamicHyperParam(8), 'epochs': IntDynamicHyperParam(8),
# number of worker processes # Number of worker processes
'n_workers': 8, 'n_workers': 8,
# number of steps to run on each process for a single update # Number of steps to run on each process for a single update
'worker_steps': 128, 'worker_steps': 128,
# number of mini batches # Number of mini batches
'batches': 4, 'batches': 4,
# Value loss coefficient # ⚙️ Value loss coefficient.
# You can change this while the experiment is running.
# [![Example](https://img.shields.io/badge/example-hyperparams-brightgreen)](https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f/hyper_params)
'value_loss_coef': FloatDynamicHyperParam(0.5), 'value_loss_coef': FloatDynamicHyperParam(0.5),
# Entropy bonus coefficient # ⚙️ Entropy bonus coefficient.
# You can change this while the experiment is running.
# [![Example](https://img.shields.io/badge/example-hyperparams-brightgreen)](https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f/hyper_params)
'entropy_bonus_coef': FloatDynamicHyperParam(0.01), 'entropy_bonus_coef': FloatDynamicHyperParam(0.01),
# Clip range # ⚙️ Clip range.
'clip_range': FloatDynamicHyperParam(0.1), 'clip_range': FloatDynamicHyperParam(0.1),
# Learning rate # ⚙️ Learning rate.
# You can change this while the experiment is running.
# [![Example](https://img.shields.io/badge/example-hyperparams-brightgreen)](https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f/hyper_params)
'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)), 'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
} }

View File

@ -8,6 +8,8 @@ summary: A PyTorch implementation/tutorial of Generalized Advantage Estimation (
This is a [PyTorch](https://pytorch.org) implementation of paper This is a [PyTorch](https://pytorch.org) implementation of paper
[Generalized Advantage Estimation](https://arxiv.org/abs/1506.02438). [Generalized Advantage Estimation](https://arxiv.org/abs/1506.02438).
You can find an experiment that uses it [here](experiment.html).
""" """
import numpy as np import numpy as np

View File

@ -14,3 +14,6 @@ is not close to the policy used to sample the data.
You can find an experiment that uses it [here](https://nn.labml.ai/rl/ppo/experiment.html). You can find an experiment that uses it [here](https://nn.labml.ai/rl/ppo/experiment.html).
The experiment uses [Generalized Advantage Estimation](https://nn.labml.ai/rl/ppo/gae.html). The experiment uses [Generalized Advantage Estimation](https://nn.labml.ai/rl/ppo/gae.html).
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/rl/ppo/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f)