Files
2023-04-02 12:10:18 +05:30

1366 lines
123 KiB
HTML
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="Annotated implementation to train a PPO agent on the Atari Breakout game."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="PPO Experiment with Atari Breakout"/>
<meta name="twitter:description" content="Annotated implementation to train a PPO agent on the Atari Breakout game."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/rl/ppo/experiment.html"/>
<meta property="og:title" content="PPO Experiment with Atari Breakout"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="PPO Experiment with Atari Breakout"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="Annotated implementation to train a PPO agent on the Atari Breakout game."/>
<title>PPO Experiment with Atari Breakout</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/rl/ppo/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
// Google Analytics bootstrap: reuse an existing dataLayer queue if one is
// already defined on window, otherwise start with an empty array.
window.dataLayer = window.dataLayer || [];
// Queue a command by pushing the call's `arguments` object onto dataLayer.
// NOTE(review): the array-like `arguments` object itself is pushed, not an
// array copy — presumably the loaded gtag.js script relies on this; confirm
// before refactoring to rest parameters or an arrow function.
function gtag() {
dataLayer.push(arguments);
}
// Queue the initial 'js' command together with the current time.
gtag('js', new Date());
// Queue the 'config' command for the same ID used in the gtag.js script URL.
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">rl</a>
<a class="parent" href="index.html">ppo</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/rl/ppo/experiment.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>PPO Experiment with Atari Breakout</h1>
<p>This experiment trains a Proximal Policy Optimization (PPO) agent on the Atari Breakout game using OpenAI Gym. It runs the <a href="../game.html">game environments on multiple processes</a> to sample efficiently.</p>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/ppo/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span>
<span class="lineno">16</span>
<span class="lineno">17</span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="lineno">18</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">optim</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">torch.distributions</span> <span class="kn">import</span> <span class="n">Categorical</span>
<span class="lineno">22</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span><span class="p">,</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">logger</span><span class="p">,</span> <span class="n">experiment</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">FloatDynamicHyperParam</span><span class="p">,</span> <span class="n">IntDynamicHyperParam</span>
<span class="lineno">25</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">26</span><span class="kn">from</span> <span class="nn">labml_nn.rl.game</span> <span class="kn">import</span> <span class="n">Worker</span>
<span class="lineno">27</span><span class="kn">from</span> <span class="nn">labml_nn.rl.ppo</span> <span class="kn">import</span> <span class="n">ClippedPPOLoss</span><span class="p">,</span> <span class="n">ClippedValueFunctionLoss</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">labml_nn.rl.ppo.gae</span> <span class="kn">import</span> <span class="n">GAE</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p>Select device </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">31</span><span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span>
<span class="lineno">32</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cuda:0&quot;</span><span class="p">)</span>
<span class="lineno">33</span><span class="k">else</span><span class="p">:</span>
<span class="lineno">34</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<h2>Model</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">43</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>The first convolution layer takes an 84x84 frame and produces a 20x20 frame </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>The second convolution layer takes a 20x20 frame and produces a 9x9 frame </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>The third convolution layer takes a 9x9 frame and produces a 7x7 frame </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>A fully connected layer takes the flattened frame from the third convolution layer, and outputs 512 features </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="bp">self</span><span class="o">.</span><span class="n">lin</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">7</span> <span class="o">*</span> <span class="mi">7</span> <span class="o">*</span> <span class="mi">64</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">512</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>A fully connected layer to get logits for <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqm" style=""><span class="mord mathnormal" style="margin-right:0.03588em">π</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>A fully connected layer to get the value function </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>ReLU activation, applied after each convolution layer and the first fully connected layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">obs</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">72</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">obs</span><span class="p">))</span>
<span class="lineno">73</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">h</span><span class="p">))</span>
<span class="lineno">74</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv3</span><span class="p">(</span><span class="n">h</span><span class="p">))</span>
<span class="lineno">75</span> <span class="n">h</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">7</span> <span class="o">*</span> <span class="mi">7</span> <span class="o">*</span> <span class="mi">64</span><span class="p">))</span>
<span class="lineno">76</span>
<span class="lineno">77</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lin</span><span class="p">(</span><span class="n">h</span><span class="p">))</span>
<span class="lineno">78</span>
<span class="lineno">79</span> <span class="n">pi</span> <span class="o">=</span> <span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span><span class="p">(</span><span class="n">h</span><span class="p">))</span>
<span class="lineno">80</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">81</span>
<span class="lineno">82</span> <span class="k">return</span> <span class="n">pi</span><span class="p">,</span> <span class="n">value</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Scale observations from <code class="highlight"><span></span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">]</span></code>
to <code class="highlight"><span></span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span><span class="k">def</span> <span class="nf">obs_to_torch</span><span class="p">(</span><span class="n">obs</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">obs</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<h2>Trainer</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">90</span><span class="k">class</span> <span class="nc">Trainer</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">96</span> <span class="n">updates</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">epochs</span><span class="p">:</span> <span class="n">IntDynamicHyperParam</span><span class="p">,</span>
<span class="lineno">97</span> <span class="n">n_workers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">worker_steps</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">batches</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">98</span> <span class="n">value_loss_coef</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">,</span>
<span class="lineno">99</span> <span class="n">entropy_bonus_coef</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">,</span>
<span class="lineno">100</span> <span class="n">clip_range</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">,</span>
<span class="lineno">101</span> <span class="n">learning_rate</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">,</span>
<span class="lineno">102</span> <span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<h4>Configurations</h4>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>number of updates </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="bp">self</span><span class="o">.</span><span class="n">updates</span> <span class="o">=</span> <span class="n">updates</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>number of epochs to train the model with sampled data </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="bp">self</span><span class="o">.</span><span class="n">epochs</span> <span class="o">=</span> <span class="n">epochs</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>number of worker processes </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span> <span class="o">=</span> <span class="n">n_workers</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>number of steps to run on each process for a single update </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span> <span class="o">=</span> <span class="n">worker_steps</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>number of mini batches </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="bp">self</span><span class="o">.</span><span class="n">batches</span> <span class="o">=</span> <span class="n">batches</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>total number of samples for a single update </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>size of a mini batch </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">118</span> <span class="bp">self</span><span class="o">.</span><span class="n">mini_batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">batches</span>
<span class="lineno">119</span> <span class="k">assert</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">batches</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Value loss coefficient </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_loss_coef</span> <span class="o">=</span> <span class="n">value_loss_coef</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Entropy bonus coefficient </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="bp">self</span><span class="o">.</span><span class="n">entropy_bonus_coef</span> <span class="o">=</span> <span class="n">entropy_bonus_coef</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Clipping range </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_range</span> <span class="o">=</span> <span class="n">clip_range</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Learning rate </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="n">learning_rate</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<h4>Initialize</h4>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>create workers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="bp">self</span><span class="o">.</span><span class="n">workers</span> <span class="o">=</span> <span class="p">[</span><span class="n">Worker</span><span class="p">(</span><span class="mi">47</span> <span class="o">+</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>initialize the observation array and reset each worker to get its first observation </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="bp">self</span><span class="o">.</span><span class="n">obs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">84</span><span class="p">,</span> <span class="mi">84</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
<span class="lineno">138</span> <span class="k">for</span> <span class="n">worker</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">:</span>
<span class="lineno">139</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">send</span><span class="p">((</span><span class="s2">&quot;reset&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">))</span>
<span class="lineno">140</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">worker</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">):</span>
<span class="lineno">141</span> <span class="bp">self</span><span class="o">.</span><span class="n">obs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">recv</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>model </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>optimizer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">2.5e-4</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>GAE with <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord mathnormal" style="margin-right:0.05556em;">γ</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">0.99</span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal">λ</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">0.95</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="bp">self</span><span class="o">.</span><span class="n">gae</span> <span class="o">=</span> <span class="n">GAE</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">,</span> <span class="mf">0.99</span><span class="p">,</span> <span class="mf">0.95</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>PPO Loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_loss</span> <span class="o">=</span> <span class="n">ClippedPPOLoss</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Value Loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_loss</span> <span class="o">=</span> <span class="n">ClippedValueFunctionLoss</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<h3>Sample data with current policy</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="n">rewards</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="lineno">164</span> <span class="n">actions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="lineno">165</span> <span class="n">done</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
<span class="lineno">166</span> <span class="n">obs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">84</span><span class="p">,</span> <span class="mi">84</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
<span class="lineno">167</span> <span class="n">log_pis</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="lineno">168</span> <span class="n">values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="lineno">169</span>
<span class="lineno">170</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>sample <code class="highlight"><span></span><span class="n">worker_steps</span></code>
from each worker </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">172</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p><code class="highlight"><span></span><span class="bp">self</span><span class="o">.</span><span class="n">obs</span></code>
keeps track of the last observation from each worker, which is the input for the model to sample the next action </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">175</span> <span class="n">obs</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">obs</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>sample actions from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.680865em;vertical-align:-0.250305em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqm" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3567071428571427em;margin-left:-0.02778em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">O</span><span class="mord mathnormal mtight" style="">L</span><span class="mord mathnormal mtight" style="margin-right:0.02778em">D</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.14329285714285717em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.250305em;"><span></span></span></span></span></span></span></span></span></span></span></span> for each worker; this returns arrays of size <code class="highlight"><span></span><span class="n">n_workers</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">178</span> <span class="n">pi</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">obs_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">obs</span><span class="p">))</span>
<span class="lineno">179</span> <span class="n">values</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
<span class="lineno">180</span> <span class="n">a</span> <span class="o">=</span> <span class="n">pi</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span>
<span class="lineno">181</span> <span class="n">actions</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">a</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
<span class="lineno">182</span> <span class="n">log_pis</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">pi</span><span class="o">.</span><span class="n">log_prob</span><span class="p">(</span><span class="n">a</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>run sampled actions on each worker </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">185</span> <span class="k">for</span> <span class="n">w</span><span class="p">,</span> <span class="n">worker</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">):</span>
<span class="lineno">186</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">send</span><span class="p">((</span><span class="s2">&quot;step&quot;</span><span class="p">,</span> <span class="n">actions</span><span class="p">[</span><span class="n">w</span><span class="p">,</span> <span class="n">t</span><span class="p">]))</span>
<span class="lineno">187</span>
<span class="lineno">188</span> <span class="k">for</span> <span class="n">w</span><span class="p">,</span> <span class="n">worker</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>get results after executing the actions </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">190</span> <span class="bp">self</span><span class="o">.</span><span class="n">obs</span><span class="p">[</span><span class="n">w</span><span class="p">],</span> <span class="n">rewards</span><span class="p">[</span><span class="n">w</span><span class="p">,</span> <span class="n">t</span><span class="p">],</span> <span class="n">done</span><span class="p">[</span><span class="n">w</span><span class="p">,</span> <span class="n">t</span><span class="p">],</span> <span class="n">info</span> <span class="o">=</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">recv</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>collect episode info, which is available if an episode finished; this includes total reward and length of the episode - look at <code class="highlight"><span></span><span class="n">Game</span></code>
to see how it works. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="k">if</span> <span class="n">info</span><span class="p">:</span>
<span class="lineno">196</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;reward&#39;</span><span class="p">,</span> <span class="n">info</span><span class="p">[</span><span class="s1">&#39;reward&#39;</span><span class="p">])</span>
<span class="lineno">197</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;length&#39;</span><span class="p">,</span> <span class="n">info</span><span class="p">[</span><span class="s1">&#39;length&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Get the value estimate of the state after the final step </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="n">_</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">obs_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">obs</span><span class="p">))</span>
<span class="lineno">201</span> <span class="n">values</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>calculate advantages </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">advantages</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gae</span><span class="p">(</span><span class="n">done</span><span class="p">,</span> <span class="n">rewards</span><span class="p">,</span> <span class="n">values</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>collect the sampled data in a dictionary </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">207</span> <span class="n">samples</span> <span class="o">=</span> <span class="p">{</span>
<span class="lineno">208</span> <span class="s1">&#39;obs&#39;</span><span class="p">:</span> <span class="n">obs</span><span class="p">,</span>
<span class="lineno">209</span> <span class="s1">&#39;actions&#39;</span><span class="p">:</span> <span class="n">actions</span><span class="p">,</span>
<span class="lineno">210</span> <span class="s1">&#39;values&#39;</span><span class="p">:</span> <span class="n">values</span><span class="p">[:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span>
<span class="lineno">211</span> <span class="s1">&#39;log_pis&#39;</span><span class="p">:</span> <span class="n">log_pis</span><span class="p">,</span>
<span class="lineno">212</span> <span class="s1">&#39;advantages&#39;</span><span class="p">:</span> <span class="n">advantages</span>
<span class="lineno">213</span> <span class="p">}</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>samples are currently in a <code class="highlight"><span></span><span class="p">[</span><span class="n">workers</span><span class="p">,</span> <span class="n">time_step</span><span class="p">]</span></code>
table; we flatten them for training </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="n">samples_flat</span> <span class="o">=</span> <span class="p">{}</span>
<span class="lineno">218</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">samples</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="lineno">219</span> <span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="o">*</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:])</span>
<span class="lineno">220</span> <span class="k">if</span> <span class="n">k</span> <span class="o">==</span> <span class="s1">&#39;obs&#39;</span><span class="p">:</span>
<span class="lineno">221</span> <span class="n">samples_flat</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">obs_to_torch</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
<span class="lineno">222</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">223</span> <span class="n">samples_flat</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">224</span>
<span class="lineno">225</span> <span class="k">return</span> <span class="n">samples_flat</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<h3>Train the model based on samples</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">227</span> <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">samples</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>It learns faster with a higher number of epochs, but becomes a little unstable; that is, the average episode reward does not monotonically increase over time. Reducing the clipping range might solve it. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">237</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">epochs</span><span class="p">()):</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>shuffle for each epoch </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">239</span> <span class="n">indexes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randperm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>for each mini batch </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">242</span> <span class="k">for</span> <span class="n">start</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mini_batch_size</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>get mini batch </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">244</span> <span class="n">end</span> <span class="o">=</span> <span class="n">start</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">mini_batch_size</span>
<span class="lineno">245</span> <span class="n">mini_batch_indexes</span> <span class="o">=</span> <span class="n">indexes</span><span class="p">[</span><span class="n">start</span><span class="p">:</span> <span class="n">end</span><span class="p">]</span>
<span class="lineno">246</span> <span class="n">mini_batch</span> <span class="o">=</span> <span class="p">{}</span>
<span class="lineno">247</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">samples</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="lineno">248</span> <span class="n">mini_batch</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="p">[</span><span class="n">mini_batch_indexes</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>train </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">251</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_calc_loss</span><span class="p">(</span><span class="n">mini_batch</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>Set learning rate </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">254</span> <span class="k">for</span> <span class="n">pg</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
<span class="lineno">255</span> <span class="n">pg</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>Zero out the previously calculated gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">257</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p>Calculate gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">259</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Clip gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">261</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">max_norm</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Update parameters based on gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">263</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<h4>Normalize advantage function</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">265</span> <span class="nd">@staticmethod</span>
<span class="lineno">266</span> <span class="k">def</span> <span class="nf">_normalize</span><span class="p">(</span><span class="n">adv</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">268</span> <span class="k">return</span> <span class="p">(</span><span class="n">adv</span> <span class="o">-</span> <span class="n">adv</span><span class="o">.</span><span class="n">mean</span><span class="p">())</span> <span class="o">/</span> <span class="p">(</span><span class="n">adv</span><span class="o">.</span><span class="n">std</span><span class="p">()</span> <span class="o">+</span> <span class="mf">1e-8</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<h3>Calculate total loss</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">270</span> <span class="k">def</span> <span class="nf">_calc_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">samples</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.00773em;">R</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.00773em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> returns sampled from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.680865em;vertical-align:-0.250305em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqm" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3567071428571427em;margin-left:-0.02778em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" 
style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">O</span><span class="mord mathnormal mtight" style="">L</span><span class="mord mathnormal mtight" style="margin-right:0.02778em">D</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.14329285714285717em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.250305em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">276</span> <span class="n">sampled_return</span> <span class="o">=</span> <span class="n">samples</span><span class="p">[</span><span class="s1">&#39;values&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="n">samples</span><span class="p">[</span><span class="s1">&#39;advantages&#39;</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.9701099999999999em;vertical-align:-0.15em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8201099999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">ˉ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.7954779999999997em;vertical-align:-0.6477389999999998em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.1477389999999998em;"><span style="top:-2.527261em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" 
style="margin-right:0.03588em;">σ</span><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqj" style=""><span class="mord accent mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-2.7em;"><span class="pstrut" style="height:2.7em;"></span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.29634285714285713em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span style="top:-2.9523300000000003em;"><span class="pstrut" style="height:2.7em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord mtight" style="">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mclose mtight">)</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.485em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqj" style=""><span class="mord accent mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-2.7em;"><span class="pstrut" style="height:2.7em;"></span><span class="mord mtight" 
style=""><span class="mord mathnormal mtight" style="">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.29634285714285713em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span style="top:-2.9523300000000003em;"><span class="pstrut" style="height:2.7em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord mtight" style="">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mbin mtight">−</span><span class="mord mathnormal mtight">μ</span><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqj" style=""><span class="mord accent mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-2.7em;"><span class="pstrut" style="height:2.7em;"></span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.29634285714285713em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span 
style="top:-2.9523300000000003em;"><span class="pstrut" style="height:2.7em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord mtight" style="">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.6477389999999998em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span>, where <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.0967699999999998em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqj" style=""><span class="mord accent" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord" style="">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> is advantages sampled from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.680865em;vertical-align:-0.250305em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqm" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3567071428571427em;margin-left:-0.02778em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">O</span><span class="mord mathnormal mtight" style="">L</span><span class="mord mathnormal mtight" style="margin-right:0.02778em">D</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.14329285714285717em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.250305em;"><span></span></span></span></span></span></span></span></span></span></span></span>. 
Refer to the sampling function in the <a href="#main">Main class</a> below for the calculation of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.0967699999999998em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal">A</span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.11110999999999999em;"><span class="mord">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">282</span> <span class="n">sampled_normalized_advantage</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_normalize</span><span class="p">(</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;advantages&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<p>Sampled observations are fed into the model to get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord coloredeq eqm" style=""><span class="mord mathnormal" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord mathnormal" style="">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal 
mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.664392em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqm" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal 
mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span>; we are treating observations as state </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">286</span> <span class="n">pi</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;obs&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord coloredeq eqm" style=""><span class="mord mathnormal" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord mathnormal" style="">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span 
style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord mathnormal" style="">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> are actions sampled from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.680865em;vertical-align:-0.250305em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqm" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord 
mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3567071428571427em;margin-left:-0.02778em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">O</span><span class="mord mathnormal mtight" style="">L</span><span class="mord mathnormal mtight" style="margin-right:0.02778em">D</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.14329285714285717em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.250305em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">289</span> <span class="n">log_pi</span> <span class="o">=</span> <span class="n">pi</span><span class="o">.</span><span class="n">log_prob</span><span class="p">(</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;actions&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<p>Calculate policy loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">292</span> <span class="n">policy_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_loss</span><span class="p">(</span><span class="n">log_pi</span><span class="p">,</span> <span class="n">samples</span><span class="p">[</span><span class="s1">&#39;log_pis&#39;</span><span class="p">],</span> <span class="n">sampled_normalized_advantage</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_range</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
<p>Calculate Entropy Bonus</p>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.0913309999999998em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05017em;">EB</span></span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.80002em;vertical-align:-0.65002em;"></span><span class="mord mathbb">E</span><span class="mopen"><span class="delimsizing size2">[</span></span><span class="mord mathnormal" style="margin-right:0.05764em;">S</span><span class="mopen"><span class="delimsizing size1">[</span></span><span class="mord"><span class="mord coloredeq eqm" style=""><span class="mord mathnormal" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span 
class="mclose"><span class="delimsizing size1">]</span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mclose"><span class="delimsizing size2">]</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">298</span> <span class="n">entropy_bonus</span> <span class="o">=</span> <span class="n">pi</span><span class="o">.</span><span class="n">entropy</span><span class="p">()</span>
<span class="lineno">299</span> <span class="n">entropy_bonus</span> <span class="o">=</span> <span class="n">entropy_bonus</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
<p>Calculate value function loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">302</span> <span class="n">value_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_loss</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">samples</span><span class="p">[</span><span class="s1">&#39;values&#39;</span><span class="p">],</span> <span class="n">sampled_return</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_range</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.0913309999999998em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.07153em;">C</span><span class="mord mathnormal mtight">L</span><span class="mord mathnormal mtight" style="margin-right:0.07847em;">I</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">P</span><span class="mbin mtight">+</span><span class="mord mathnormal mtight" style="margin-right:0.22222em;">V</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">F</span><span class="mbin mtight">+</span><span class="mord mathnormal mtight" style="margin-right:0.05017em;">EB</span></span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.0913309999999998em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.07153em;">C</span><span 
class="mord mathnormal mtight">L</span><span class="mord mathnormal mtight" style="margin-right:0.07847em;">I</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">P</span></span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.0913309999999998em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal">c</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.22222em;">V</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">F</span></span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span 
class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.0913309999999998em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal">c</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05017em;">EB</span></span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">)</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">307</span> <span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">policy_loss</span>
<span class="lineno">308</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_loss_coef</span><span class="p">()</span> <span class="o">*</span> <span class="n">value_loss</span>
<span class="lineno">309</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">entropy_bonus_coef</span><span class="p">()</span> <span class="o">*</span> <span class="n">entropy_bonus</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-70'>
<div class='docs'>
<div class='section-link'>
<a href='#section-70'>#</a>
</div>
<p>for monitoring </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">312</span> <span class="n">approx_kl_divergence</span> <span class="o">=</span> <span class="mf">.5</span> <span class="o">*</span> <span class="p">((</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;log_pis&#39;</span><span class="p">]</span> <span class="o">-</span> <span class="n">log_pi</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-71'>
<div class='docs'>
<div class='section-link'>
<a href='#section-71'>#</a>
</div>
<p>Add to tracker </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">315</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">({</span><span class="s1">&#39;policy_reward&#39;</span><span class="p">:</span> <span class="o">-</span><span class="n">policy_loss</span><span class="p">,</span>
<span class="lineno">316</span> <span class="s1">&#39;value_loss&#39;</span><span class="p">:</span> <span class="n">value_loss</span><span class="p">,</span>
<span class="lineno">317</span> <span class="s1">&#39;entropy_bonus&#39;</span><span class="p">:</span> <span class="n">entropy_bonus</span><span class="p">,</span>
<span class="lineno">318</span> <span class="s1">&#39;kl_div&#39;</span><span class="p">:</span> <span class="n">approx_kl_divergence</span><span class="p">,</span>
<span class="lineno">319</span> <span class="s1">&#39;clip_fraction&#39;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_loss</span><span class="o">.</span><span class="n">clip_fraction</span><span class="p">})</span>
<span class="lineno">320</span>
<span class="lineno">321</span> <span class="k">return</span> <span class="n">loss</span></pre></div>
</div>
</div>
<div class='section' id='section-72'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-72'>#</a>
</div>
<h3>Run training loop</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">323</span> <span class="k">def</span> <span class="nf">run_training_loop</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-73'>
<div class='docs'>
<div class='section-link'>
<a href='#section-73'>#</a>
</div>
<p>last 100 episode information </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">329</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_queue</span><span class="p">(</span><span class="s1">&#39;reward&#39;</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">330</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_queue</span><span class="p">(</span><span class="s1">&#39;length&#39;</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">331</span>
<span class="lineno">332</span> <span class="k">for</span> <span class="n">update</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">updates</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-74'>
<div class='docs'>
<div class='section-link'>
<a href='#section-74'>#</a>
</div>
<p>sample with current policy </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">334</span> <span class="n">samples</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-75'>
<div class='docs'>
<div class='section-link'>
<a href='#section-75'>#</a>
</div>
<p>train the model </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">337</span> <span class="bp">self</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">samples</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-76'>
<div class='docs'>
<div class='section-link'>
<a href='#section-76'>#</a>
</div>
<p>Save tracked indicators. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">340</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-77'>
<div class='docs'>
<div class='section-link'>
<a href='#section-77'>#</a>
</div>
<p>Add a new line to the screen periodically </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">342</span> <span class="k">if</span> <span class="p">(</span><span class="n">update</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">1_000</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">343</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-78'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-78'>#</a>
</div>
<h3>Destroy</h3>
<p>Stop the workers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">345</span> <span class="k">def</span> <span class="nf">destroy</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-79'>
<div class='docs'>
<div class='section-link'>
<a href='#section-79'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">350</span> <span class="k">for</span> <span class="n">worker</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">:</span>
<span class="lineno">351</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">send</span><span class="p">((</span><span class="s2">&quot;close&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-80'>
<div class='docs'>
<div class='section-link'>
<a href='#section-80'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">354</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-81'>
<div class='docs'>
<div class='section-link'>
<a href='#section-81'>#</a>
</div>
<p>Create the experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">356</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;ppo&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-82'>
<div class='docs'>
<div class='section-link'>
<a href='#section-82'>#</a>
</div>
<p>Configurations </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">358</span> <span class="n">configs</span> <span class="o">=</span> <span class="p">{</span></pre></div>
</div>
</div>
<div class='section' id='section-83'>
<div class='docs'>
<div class='section-link'>
<a href='#section-83'>#</a>
</div>
<p>Number of updates </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">360</span> <span class="s1">&#39;updates&#39;</span><span class="p">:</span> <span class="mi">10000</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-84'>
<div class='docs'>
<div class='section-link'>
<a href='#section-84'>#</a>
</div>
<p>⚙️ Number of epochs to train the model with sampled data. You can change this while the experiment is running. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">363</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="n">IntDynamicHyperParam</span><span class="p">(</span><span class="mi">8</span><span class="p">),</span></pre></div>
</div>
</div>
<div class='section' id='section-85'>
<div class='docs'>
<div class='section-link'>
<a href='#section-85'>#</a>
</div>
<p>Number of worker processes </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">365</span> <span class="s1">&#39;n_workers&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-86'>
<div class='docs'>
<div class='section-link'>
<a href='#section-86'>#</a>
</div>
<p>Number of steps to run on each process for a single update </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">367</span> <span class="s1">&#39;worker_steps&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-87'>
<div class='docs'>
<div class='section-link'>
<a href='#section-87'>#</a>
</div>
<p>Number of mini batches </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">369</span> <span class="s1">&#39;batches&#39;</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-88'>
<div class='docs'>
<div class='section-link'>
<a href='#section-88'>#</a>
</div>
<p>⚙️ Value loss coefficient. You can change this while the experiment is running. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">372</span> <span class="s1">&#39;value_loss_coef&#39;</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">(</span><span class="mf">0.5</span><span class="p">),</span></pre></div>
</div>
</div>
<div class='section' id='section-89'>
<div class='docs'>
<div class='section-link'>
<a href='#section-89'>#</a>
</div>
<p>⚙️ Entropy bonus coefficient. You can change this while the experiment is running. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">375</span> <span class="s1">&#39;entropy_bonus_coef&#39;</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">(</span><span class="mf">0.01</span><span class="p">),</span></pre></div>
</div>
</div>
<div class='section' id='section-90'>
<div class='docs'>
<div class='section-link'>
<a href='#section-90'>#</a>
</div>
<p>⚙️ Clip range. You can change this while the experiment is running. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">377</span> <span class="s1">&#39;clip_range&#39;</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">(</span><span class="mf">0.1</span><span class="p">),</span></pre></div>
</div>
</div>
<div class='section' id='section-91'>
<div class='docs'>
<div class='section-link'>
<a href='#section-91'>#</a>
</div>
<p>⚙️ Learning rate. You can change this while the experiment is running. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">380</span> <span class="s1">&#39;learning_rate&#39;</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">(</span><span class="mf">1e-3</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1e-3</span><span class="p">)),</span>
<span class="lineno">381</span> <span class="p">}</span>
<span class="lineno">382</span>
<span class="lineno">383</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">configs</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-92'>
<div class='docs'>
<div class='section-link'>
<a href='#section-92'>#</a>
</div>
<p>Initialize the trainer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">386</span> <span class="n">m</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="o">**</span><span class="n">configs</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-93'>
<div class='docs'>
<div class='section-link'>
<a href='#section-93'>#</a>
</div>
<p>Run and monitor the experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">389</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">390</span> <span class="n">m</span><span class="o">.</span><span class="n">run_training_loop</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-94'>
<div class='docs'>
<div class='section-link'>
<a href='#section-94'>#</a>
</div>
<p>Stop the workers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">392</span> <span class="n">m</span><span class="o">.</span><span class="n">destroy</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-95'>
<div class='docs'>
<div class='section-link'>
<a href='#section-95'>#</a>
</div>
<h2>Run it</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">396</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
<span class="lineno">397</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
// Enable the click-to-enlarge modal on every image that is a direct child
// of a paragraph element.
function handleImages() {
    var paragraphImages = document.querySelectorAll('p>img')

    // querySelectorAll returns an array-like list; borrow Array's forEach
    // to visit each image in document order.
    Array.prototype.forEach.call(paragraphImages, function (image) {
        handleImage(image)
    })
}
/**
 * Make an image open in a modal overlay when it is clicked.
 *
 * Centers the image within its parent element, then builds a detached
 * modal: a `div#modal` holding a content wrapper with an enlarged copy of
 * the image, followed by a `span.close` control. Clicking the image
 * attaches the modal to `document.body`; clicking the close control
 * detaches it again.
 *
 * @param {HTMLImageElement} img Image to make enlargeable.
 */
function handleImage(img) {
    img.parentElement.style.textAlign = 'center'

    var modal = document.createElement('div')
    modal.id = 'modal'

    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)

    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)

    var closeControl = document.createElement('span')
    closeControl.classList.add('close')
    closeControl.textContent = 'x'
    modal.appendChild(closeControl)

    img.onclick = function () {
        document.body.appendChild(modal)
        modalImage.src = img.src
        // Carry the alternative text over so the enlarged copy is not an
        // unlabeled image for assistive technology.
        modalImage.alt = img.alt || ''
    }

    closeControl.onclick = function () {
        document.body.removeChild(modal)
    }
}
// Wire up the click-to-enlarge behaviour for the paragraph images present
// in the document at the point this script runs.
handleImages()
</script>
</body>
</html>