Files
Varuna Jayasiri 2038b11d29 ja translation
2023-05-10 17:00:29 -04:00

1366 lines
124 KiB
HTML
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html lang="ja">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="Atari Breakout ゲームで PPO エージェントをトレーニングするための注釈付き実装。"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="アタリ・ブレイクアウトによるPPO実験"/>
<meta name="twitter:description" content="Atari Breakout ゲームで PPO エージェントをトレーニングするための注釈付き実装。"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/rl/ppo/experiment.html"/>
<meta property="og:title" content="アタリ・ブレイクアウトによるPPO実験"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="アタリ・ブレイクアウトによるPPO実験"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="アタリ・ブレイクアウトによるPPO実験"/>
<meta property="og:description" content="Atari Breakout ゲームで PPO エージェントをトレーニングするための注釈付き実装。"/>
<title>アタリ・ブレイクアウトによるPPO実験</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/rl/ppo/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
// Google Analytics global site tag bootstrap.
// Reuse an existing window.dataLayer array if one was already created, otherwise start a new one.
window.dataLayer = window.dataLayer || [];
// gtag(): queue a command by pushing this call's raw `arguments` object onto dataLayer.
// NOTE(review): gtag.js presumably relies on receiving the Arguments object itself
// (not a copied array) — confirm before rewriting this with rest parameters.
function gtag() {
dataLayer.push(arguments);
}
// Queue a 'js' command carrying the current time.
gtag('js', new Date());
// Queue a 'config' command for measurement ID G-4V3HC8HBLH (same ID as in the gtag.js script URL).
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">rl</a>
<a class="parent" href="index.html">ppo</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/rl/ppo/experiment.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>アタリ・ブレイクアウトによるPPO実験</h1>
<p>この実験では、OpenAI Gym の Atari ブレイクアウトゲームで、近接方策最適化 (Proximal Policy Optimization, PPO) エージェントをトレーニングします。<a href="../game.html">ゲーム環境を複数のプロセスで実行して効率的にサンプリングします</a>。</p>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/ppo/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span>
<span class="lineno">16</span>
<span class="lineno">17</span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="lineno">18</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">optim</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">torch.distributions</span> <span class="kn">import</span> <span class="n">Categorical</span>
<span class="lineno">22</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span><span class="p">,</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">logger</span><span class="p">,</span> <span class="n">experiment</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">FloatDynamicHyperParam</span><span class="p">,</span> <span class="n">IntDynamicHyperParam</span>
<span class="lineno">25</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">26</span><span class="kn">from</span> <span class="nn">labml_nn.rl.game</span> <span class="kn">import</span> <span class="n">Worker</span>
<span class="lineno">27</span><span class="kn">from</span> <span class="nn">labml_nn.rl.ppo</span> <span class="kn">import</span> <span class="n">ClippedPPOLoss</span><span class="p">,</span> <span class="n">ClippedValueFunctionLoss</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">labml_nn.rl.ppo.gae</span> <span class="kn">import</span> <span class="n">GAE</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p>デバイスを選択</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">31</span><span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span>
<span class="lineno">32</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cuda:0&quot;</span><span class="p">)</span>
<span class="lineno">33</span><span class="k">else</span><span class="p">:</span>
<span class="lineno">34</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<h2>モデル</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">43</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>最初の畳み込み層は 84x84 のフレームを受け取り、20x20 のフレームを出力します。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>2 番目の畳み込み層は 20x20 のフレームを受け取り、9x9 のフレームを出力します。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>3 番目の畳み込み層は 9x9 のフレームを受け取り、7x7 のフレームを出力します。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>全結合層は、3 番目の畳み込み層の出力を平坦化したものを受け取り、512 個の特徴量を出力します。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="bp">self</span><span class="o">.</span><span class="n">lin</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">7</span> <span class="o">*</span> <span class="mi">7</span> <span class="o">*</span> <span class="mi">64</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">512</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqm" style=""><span class="mord mathnormal" style="margin-right:0.03588em">π</span></span></span></span></span></span> のロジットを得るための全結合層</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>価値関数を得るための全結合層</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>活性化関数 (ReLU)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">obs</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">72</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">obs</span><span class="p">))</span>
<span class="lineno">73</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">h</span><span class="p">))</span>
<span class="lineno">74</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv3</span><span class="p">(</span><span class="n">h</span><span class="p">))</span>
<span class="lineno">75</span> <span class="n">h</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">7</span> <span class="o">*</span> <span class="mi">7</span> <span class="o">*</span> <span class="mi">64</span><span class="p">))</span>
<span class="lineno">76</span>
<span class="lineno">77</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lin</span><span class="p">(</span><span class="n">h</span><span class="p">))</span>
<span class="lineno">78</span>
<span class="lineno">79</span> <span class="n">pi</span> <span class="o">=</span> <span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span><span class="p">(</span><span class="n">h</span><span class="p">))</span>
<span class="lineno">80</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">81</span>
<span class="lineno">82</span> <span class="k">return</span> <span class="n">pi</span><span class="p">,</span> <span class="n">value</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>観測値を <code class="highlight"><span></span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">]</span></code> から <code class="highlight"><span></span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span></code> にスケーリングします
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span><span class="k">def</span> <span class="nf">obs_to_torch</span><span class="p">(</span><span class="n">obs</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">obs</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<h2>トレーナー</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">90</span><span class="k">class</span> <span class="nc">Trainer</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">96</span> <span class="n">updates</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">epochs</span><span class="p">:</span> <span class="n">IntDynamicHyperParam</span><span class="p">,</span>
<span class="lineno">97</span> <span class="n">n_workers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">worker_steps</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">batches</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">98</span> <span class="n">value_loss_coef</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">,</span>
<span class="lineno">99</span> <span class="n">entropy_bonus_coef</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">,</span>
<span class="lineno">100</span> <span class="n">clip_range</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">,</span>
<span class="lineno">101</span> <span class="n">learning_rate</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">,</span>
<span class="lineno">102</span> <span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<h4>コンフィギュレーション</h4>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>更新回数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="bp">self</span><span class="o">.</span><span class="n">updates</span> <span class="o">=</span> <span class="n">updates</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>サンプルデータを使用してモデルをトレーニングするエポックの数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="bp">self</span><span class="o">.</span><span class="n">epochs</span> <span class="o">=</span> <span class="n">epochs</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>ワーカープロセスの数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span> <span class="o">=</span> <span class="n">n_workers</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>1 回の更新で各プロセスで実行するステップの数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span> <span class="o">=</span> <span class="n">worker_steps</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>ミニバッチ数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="bp">self</span><span class="o">.</span><span class="n">batches</span> <span class="o">=</span> <span class="n">batches</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>1 回の更新でのサンプルの総数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>ミニバッチのサイズ</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">118</span> <span class="bp">self</span><span class="o">.</span><span class="n">mini_batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">batches</span>
<span class="lineno">119</span> <span class="k">assert</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">batches</span> <span class="o">==</span> <span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>価値損失係数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_loss_coef</span> <span class="o">=</span> <span class="n">value_loss_coef</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>エントロピーボーナス係数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="bp">self</span><span class="o">.</span><span class="n">entropy_bonus_coef</span> <span class="o">=</span> <span class="n">entropy_bonus_coef</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>クリッピング範囲</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_range</span> <span class="o">=</span> <span class="n">clip_range</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>学習率</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="n">learning_rate</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<h4>初期化</h4>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>ワーカーを作成</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="bp">self</span><span class="o">.</span><span class="n">workers</span> <span class="o">=</span> <span class="p">[</span><span class="n">Worker</span><span class="p">(</span><span class="mi">47</span> <span class="o">+</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>観測用のテンソルを初期化</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="bp">self</span><span class="o">.</span><span class="n">obs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">84</span><span class="p">,</span> <span class="mi">84</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
<span class="lineno">138</span> <span class="k">for</span> <span class="n">worker</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">:</span>
<span class="lineno">139</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">send</span><span class="p">((</span><span class="s2">&quot;reset&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">))</span>
<span class="lineno">140</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">worker</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">):</span>
<span class="lineno">141</span> <span class="bp">self</span><span class="o">.</span><span class="n">obs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">recv</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>モデル</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>オプティマイザー</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">2.5e-4</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>GAE (<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord mathnormal" style="margin-right:0.05556em;">γ</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">0.99</span></span></span></span></span>、<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal">λ</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">0.95</span></span></span></span></span>)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="bp">self</span><span class="o">.</span><span class="n">gae</span> <span class="o">=</span> <span class="n">GAE</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">,</span> <span class="mf">0.99</span><span class="p">,</span> <span class="mf">0.95</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>PPO 損失</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_loss</span> <span class="o">=</span> <span class="n">ClippedPPOLoss</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>価値損失</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_loss</span> <span class="o">=</span> <span class="n">ClippedValueFunctionLoss</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<h3>現在のポリシーでデータをサンプリング</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="n">rewards</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="lineno">164</span> <span class="n">actions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="lineno">165</span> <span class="n">done</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
<span class="lineno">166</span> <span class="n">obs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">84</span><span class="p">,</span> <span class="mi">84</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
<span class="lineno">167</span> <span class="n">log_pis</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="lineno">168</span> <span class="n">values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="lineno">169</span>
<span class="lineno">170</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>各ワーカーから <code class="highlight"><span></span><span class="n">worker_steps</span></code>
ステップ分をサンプリング</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">172</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p><code class="highlight"><span></span><span class="bp">self</span><span class="o">.</span><span class="n">obs</span></code>
は各ワーカーからの最後の観測値を保持します。これは、モデルが次のアクションをサンプリングするための入力です。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">175</span> <span class="n">obs</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">obs</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>各ワーカーについて、<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.680865em;vertical-align:-0.250305em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqm" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3567071428571427em;margin-left:-0.02778em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">O</span><span class="mord mathnormal mtight" style="">L</span><span class="mord mathnormal mtight" style="margin-right:0.02778em">D</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.14329285714285717em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.250305em;"><span></span></span></span></span></span></span></span></span></span></span></span> からアクションをサンプリングします。これはサイズ <code class="highlight"><span></span><span class="n">n_workers</span></code>
の配列を返します。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">178</span> <span class="n">pi</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">obs_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">obs</span><span class="p">))</span>
<span class="lineno">179</span> <span class="n">values</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
<span class="lineno">180</span> <span class="n">a</span> <span class="o">=</span> <span class="n">pi</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span>
<span class="lineno">181</span> <span class="n">actions</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">a</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
<span class="lineno">182</span> <span class="n">log_pis</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">pi</span><span class="o">.</span><span class="n">log_prob</span><span class="p">(</span><span class="n">a</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>サンプリングしたアクションを各ワーカーで実行</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">185</span> <span class="k">for</span> <span class="n">w</span><span class="p">,</span> <span class="n">worker</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">):</span>
<span class="lineno">186</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">send</span><span class="p">((</span><span class="s2">&quot;step&quot;</span><span class="p">,</span> <span class="n">actions</span><span class="p">[</span><span class="n">w</span><span class="p">,</span> <span class="n">t</span><span class="p">]))</span>
<span class="lineno">187</span>
<span class="lineno">188</span> <span class="k">for</span> <span class="n">w</span><span class="p">,</span> <span class="n">worker</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>アクションを実行した後に結果を取得</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">190</span> <span class="bp">self</span><span class="o">.</span><span class="n">obs</span><span class="p">[</span><span class="n">w</span><span class="p">],</span> <span class="n">rewards</span><span class="p">[</span><span class="n">w</span><span class="p">,</span> <span class="n">t</span><span class="p">],</span> <span class="n">done</span><span class="p">[</span><span class="n">w</span><span class="p">,</span> <span class="n">t</span><span class="p">],</span> <span class="n">info</span> <span class="o">=</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">recv</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>エピソードの情報を収集します。この情報はエピソードが終了したときにのみ得られ、報酬の合計とエピソードの長さが含まれます。仕組みについては <code class="highlight"><span></span><span class="n">Game</span></code>
を参照してください。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="k">if</span> <span class="n">info</span><span class="p">:</span>
<span class="lineno">196</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;reward&#39;</span><span class="p">,</span> <span class="n">info</span><span class="p">[</span><span class="s1">&#39;reward&#39;</span><span class="p">])</span>
<span class="lineno">197</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;length&#39;</span><span class="p">,</span> <span class="n">info</span><span class="p">[</span><span class="s1">&#39;length&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>最後のステップの後に値を取得</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="n">_</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">obs_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">obs</span><span class="p">))</span>
<span class="lineno">201</span> <span class="n">values</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>アドバンテージを計算</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">advantages</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gae</span><span class="p">(</span><span class="n">done</span><span class="p">,</span> <span class="n">rewards</span><span class="p">,</span> <span class="n">values</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">207</span> <span class="n">samples</span> <span class="o">=</span> <span class="p">{</span>
<span class="lineno">208</span> <span class="s1">&#39;obs&#39;</span><span class="p">:</span> <span class="n">obs</span><span class="p">,</span>
<span class="lineno">209</span> <span class="s1">&#39;actions&#39;</span><span class="p">:</span> <span class="n">actions</span><span class="p">,</span>
<span class="lineno">210</span> <span class="s1">&#39;values&#39;</span><span class="p">:</span> <span class="n">values</span><span class="p">[:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span>
<span class="lineno">211</span> <span class="s1">&#39;log_pis&#39;</span><span class="p">:</span> <span class="n">log_pis</span><span class="p">,</span>
<span class="lineno">212</span> <span class="s1">&#39;advantages&#39;</span><span class="p">:</span> <span class="n">advantages</span>
<span class="lineno">213</span> <span class="p">}</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>サンプルは現在 <code class="highlight"><span></span><span class="p">[</span><span class="n">workers</span><span class="p">,</span> <span class="n">time_step</span><span class="p">]</span></code>
の形のテーブルになっているため、トレーニング用に平坦化する必要があります</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="n">samples_flat</span> <span class="o">=</span> <span class="p">{}</span>
<span class="lineno">218</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">samples</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="lineno">219</span> <span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="o">*</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:])</span>
<span class="lineno">220</span> <span class="k">if</span> <span class="n">k</span> <span class="o">==</span> <span class="s1">&#39;obs&#39;</span><span class="p">:</span>
<span class="lineno">221</span> <span class="n">samples_flat</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">obs_to_torch</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
<span class="lineno">222</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">223</span> <span class="n">samples_flat</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">224</span>
<span class="lineno">225</span> <span class="k">return</span> <span class="n">samples_flat</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<h3>サンプルに基づいてモデルをトレーニングする</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">227</span> <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">samples</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>エポック数が多いほど学習は速くなりますが、少し不安定になります。つまり、エピソードの平均報酬は時間の経過とともに単調に増加しません。クリッピング範囲を狭くすることで解決する可能性があります。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">237</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">epochs</span><span class="p">()):</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>エポックごとにシャッフル</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">239</span> <span class="n">indexes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randperm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>各ミニバッチについて</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">242</span> <span class="k">for</span> <span class="n">start</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mini_batch_size</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>ミニバッチを取得</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">244</span> <span class="n">end</span> <span class="o">=</span> <span class="n">start</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">mini_batch_size</span>
<span class="lineno">245</span> <span class="n">mini_batch_indexes</span> <span class="o">=</span> <span class="n">indexes</span><span class="p">[</span><span class="n">start</span><span class="p">:</span> <span class="n">end</span><span class="p">]</span>
<span class="lineno">246</span> <span class="n">mini_batch</span> <span class="o">=</span> <span class="p">{}</span>
<span class="lineno">247</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">samples</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="lineno">248</span> <span class="n">mini_batch</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="p">[</span><span class="n">mini_batch_indexes</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>学習</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">251</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_calc_loss</span><span class="p">(</span><span class="n">mini_batch</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>学習率を設定</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">254</span> <span class="k">for</span> <span class="n">pg</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
<span class="lineno">255</span> <span class="n">pg</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>以前に計算した勾配をゼロにします</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">257</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p>勾配の計算</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">259</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>勾配をクリップ</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">261</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">max_norm</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>勾配に基づいてパラメータを更新</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">263</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<h4>アドバンテージ関数の正規化</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">265</span> <span class="nd">@staticmethod</span>
<span class="lineno">266</span> <span class="k">def</span> <span class="nf">_normalize</span><span class="p">(</span><span class="n">adv</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">268</span> <span class="k">return</span> <span class="p">(</span><span class="n">adv</span> <span class="o">-</span> <span class="n">adv</span><span class="o">.</span><span class="n">mean</span><span class="p">())</span> <span class="o">/</span> <span class="p">(</span><span class="n">adv</span><span class="o">.</span><span class="n">std</span><span class="p">()</span> <span class="o">+</span> <span class="mf">1e-8</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<h3>総損失の計算</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">270</span> <span class="k">def</span> <span class="nf">_calc_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">samples</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.680865em;vertical-align:-0.250305em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqm" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3567071428571427em;margin-left:-0.02778em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">O</span><span class="mord mathnormal mtight" style="">L</span><span class="mord mathnormal mtight" style="margin-right:0.02778em">D</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.14329285714285717em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.250305em;"><span></span></span></span></span></span></span></span></span></span></span></span> からサンプリングされたリターン <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.00773em;">R</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.00773em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">276</span> <span class="n">sampled_return</span> <span class="o">=</span> <span class="n">samples</span><span class="p">[</span><span class="s1">&#39;values&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="n">samples</span><span class="p">[</span><span class="s1">&#39;advantages&#39;</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.9701099999999999em;vertical-align:-0.15em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8201099999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">ˉ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.7954779999999997em;vertical-align:-0.6477389999999998em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.1477389999999998em;"><span style="top:-2.527261em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" 
style="margin-right:0.03588em;">σ</span><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqj" style=""><span class="mord accent mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-2.7em;"><span class="pstrut" style="height:2.7em;"></span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.29634285714285713em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span style="top:-2.9523300000000003em;"><span class="pstrut" style="height:2.7em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord mtight" style="">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mclose mtight">)</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.485em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqj" style=""><span class="mord accent mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-2.7em;"><span class="pstrut" style="height:2.7em;"></span><span class="mord mtight" 
style=""><span class="mord mathnormal mtight" style="">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.29634285714285713em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span style="top:-2.9523300000000003em;"><span class="pstrut" style="height:2.7em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord mtight" style="">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mbin mtight">−</span><span class="mord mathnormal mtight">μ</span><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqj" style=""><span class="mord accent mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-2.7em;"><span class="pstrut" style="height:2.7em;"></span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.29634285714285713em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span><span 
style="top:-2.9523300000000003em;"><span class="pstrut" style="height:2.7em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord mtight" style="">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.6477389999999998em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.0967699999999998em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqj" style=""><span class="mord accent" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord" style="">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>は<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.680865em;vertical-align:-0.250305em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqm" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3567071428571427em;margin-left:-0.02778em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">O</span><span class="mord mathnormal mtight" style="">L</span><span class="mord mathnormal mtight" style="margin-right:0.02778em">D</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.14329285714285717em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.250305em;"><span></span></span></span></span></span></span></span></span></span></span></span>からサンプリングされたアドバンテージです。<span ><span class="katex"><span 
aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.0967699999999998em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9467699999999999em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mathnormal">A</span></span><span style="top:-3.25233em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.11110999999999999em;"><span class="mord">^</span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>の計算については、<a href="#main">下記のメインクラスのサンプリング関数</a>を参照してください。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">282</span> <span class="n">sampled_normalized_advantage</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_normalize</span><span class="p">(</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;advantages&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord coloredeq eqm" style=""><span class="mord mathnormal" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord mathnormal" style="">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span 
class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> および <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.664392em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqm" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span 
class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> は、サンプリングされた観測値をモデルに入力して取得します。観測値は状態として扱います。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">286</span> <span class="n">pi</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;obs&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord">−</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord coloredeq eqm" style=""><span class="mord mathnormal" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord mathnormal" style="">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span 
style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord mathnormal" style="">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>アクションは以下からサンプリングされます <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.680865em;vertical-align:-0.250305em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqm" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" 
style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3567071428571427em;margin-left:-0.02778em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">O</span><span class="mord mathnormal mtight" style="">L</span><span class="mord mathnormal mtight" style="margin-right:0.02778em">D</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.14329285714285717em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.250305em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">289</span> <span class="n">log_pi</span> <span class="o">=</span> <span class="n">pi</span><span class="o">.</span><span class="n">log_prob</span><span class="p">(</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;actions&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<p>ポリシー損失の計算</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">292</span> <span class="n">policy_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_loss</span><span class="p">(</span><span class="n">log_pi</span><span class="p">,</span> <span class="n">samples</span><span class="p">[</span><span class="s1">&#39;log_pis&#39;</span><span class="p">],</span> <span class="n">sampled_normalized_advantage</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_range</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
<p>エントロピーボーナスの計算</p>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.0913309999999998em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05017em;">EB</span></span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.80002em;vertical-align:-0.65002em;"></span><span class="mord mathbb">E</span><span class="mopen"><span class="delimsizing size2">[</span></span><span class="mord mathnormal" style="margin-right:0.05764em;">S</span><span class="mopen"><span class="delimsizing size1">[</span></span><span class="mord"><span class="mord coloredeq eqm" style=""><span class="mord mathnormal" style="margin-right:0.03588em">π</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span 
class="mclose"><span class="delimsizing size1">]</span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mclose"><span class="delimsizing size2">]</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">298</span> <span class="n">entropy_bonus</span> <span class="o">=</span> <span class="n">pi</span><span class="o">.</span><span class="n">entropy</span><span class="p">()</span>
<span class="lineno">299</span> <span class="n">entropy_bonus</span> <span class="o">=</span> <span class="n">entropy_bonus</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
<p>値関数損失の計算</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">302</span> <span class="n">value_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_loss</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">samples</span><span class="p">[</span><span class="s1">&#39;values&#39;</span><span class="p">],</span> <span class="n">sampled_return</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_range</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.0913309999999998em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.07153em;">C</span><span class="mord mathnormal mtight">L</span><span class="mord mathnormal mtight" style="margin-right:0.07847em;">I</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">P</span><span class="mbin mtight">+</span><span class="mord mathnormal mtight" style="margin-right:0.22222em;">V</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">F</span><span class="mbin mtight">+</span><span class="mord mathnormal mtight" style="margin-right:0.05017em;">EB</span></span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.0913309999999998em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.07153em;">C</span><span 
class="mord mathnormal mtight">L</span><span class="mord mathnormal mtight" style="margin-right:0.07847em;">I</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">P</span></span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.0913309999999998em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal">c</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.22222em;">V</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">F</span></span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span 
class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.0913309999999998em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal">c</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathcal">L</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05017em;">EB</span></span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">)</span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">307</span> <span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">policy_loss</span>
<span class="lineno">308</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_loss_coef</span><span class="p">()</span> <span class="o">*</span> <span class="n">value_loss</span>
<span class="lineno">309</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">entropy_bonus_coef</span><span class="p">()</span> <span class="o">*</span> <span class="n">entropy_bonus</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-70'>
<div class='docs'>
<div class='section-link'>
<a href='#section-70'>#</a>
</div>
<p>監視用</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">312</span> <span class="n">approx_kl_divergence</span> <span class="o">=</span> <span class="mf">.5</span> <span class="o">*</span> <span class="p">((</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;log_pis&#39;</span><span class="p">]</span> <span class="o">-</span> <span class="n">log_pi</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-71'>
<div class='docs'>
<div class='section-link'>
<a href='#section-71'>#</a>
</div>
<p>トラッカーに追加</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">315</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">({</span><span class="s1">&#39;policy_reward&#39;</span><span class="p">:</span> <span class="o">-</span><span class="n">policy_loss</span><span class="p">,</span>
<span class="lineno">316</span> <span class="s1">&#39;value_loss&#39;</span><span class="p">:</span> <span class="n">value_loss</span><span class="p">,</span>
<span class="lineno">317</span> <span class="s1">&#39;entropy_bonus&#39;</span><span class="p">:</span> <span class="n">entropy_bonus</span><span class="p">,</span>
<span class="lineno">318</span> <span class="s1">&#39;kl_div&#39;</span><span class="p">:</span> <span class="n">approx_kl_divergence</span><span class="p">,</span>
<span class="lineno">319</span> <span class="s1">&#39;clip_fraction&#39;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">ppo_loss</span><span class="o">.</span><span class="n">clip_fraction</span><span class="p">})</span>
<span class="lineno">320</span>
<span class="lineno">321</span> <span class="k">return</span> <span class="n">loss</span></pre></div>
</div>
</div>
<div class='section' id='section-72'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-72'>#</a>
</div>
<h3>トレーニングループを実行</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">323</span> <span class="k">def</span> <span class="nf">run_training_loop</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-73'>
<div class='docs'>
<div class='section-link'>
<a href='#section-73'>#</a>
</div>
<p>直近 100 エピソードの情報</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">329</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_queue</span><span class="p">(</span><span class="s1">&#39;reward&#39;</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">330</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_queue</span><span class="p">(</span><span class="s1">&#39;length&#39;</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">331</span>
<span class="lineno">332</span> <span class="k">for</span> <span class="n">update</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">updates</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-74'>
<div class='docs'>
<div class='section-link'>
<a href='#section-74'>#</a>
</div>
<p>現在のポリシーでサンプリング</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">334</span> <span class="n">samples</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-75'>
<div class='docs'>
<div class='section-link'>
<a href='#section-75'>#</a>
</div>
<p>モデルのトレーニング</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">337</span> <span class="bp">self</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">samples</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-76'>
<div class='docs'>
<div class='section-link'>
<a href='#section-76'>#</a>
</div>
<p>追跡指標を保存します。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">340</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-77'>
<div class='docs'>
<div class='section-link'>
<a href='#section-77'>#</a>
</div>
<p>定期的に画面に新しい行を追加する</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">342</span> <span class="k">if</span> <span class="p">(</span><span class="n">update</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">1_000</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">343</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-78'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-78'>#</a>
</div>
<h3>破棄</h3>
<p>ワーカーを停止する</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">345</span> <span class="k">def</span> <span class="nf">destroy</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-79'>
<div class='docs'>
<div class='section-link'>
<a href='#section-79'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">350</span> <span class="k">for</span> <span class="n">worker</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">:</span>
<span class="lineno">351</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">send</span><span class="p">((</span><span class="s2">&quot;close&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-80'>
<div class='docs'>
<div class='section-link'>
<a href='#section-80'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">354</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-81'>
<div class='docs'>
<div class='section-link'>
<a href='#section-81'>#</a>
</div>
<p>実験を作成</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">356</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;ppo&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-82'>
<div class='docs'>
<div class='section-link'>
<a href='#section-82'>#</a>
</div>
<p>コンフィギュレーション</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">358</span> <span class="n">configs</span> <span class="o">=</span> <span class="p">{</span></pre></div>
</div>
</div>
<div class='section' id='section-83'>
<div class='docs'>
<div class='section-link'>
<a href='#section-83'>#</a>
</div>
<p>更新回数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">360</span> <span class="s1">&#39;updates&#39;</span><span class="p">:</span> <span class="mi">10000</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-84'>
<div class='docs'>
<div class='section-link'>
<a href='#section-84'>#</a>
</div>
<p>⚙️ サンプルデータを使用してモデルをトレーニングするエポックの数。これは実験の実行中に変更できます。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">363</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="n">IntDynamicHyperParam</span><span class="p">(</span><span class="mi">8</span><span class="p">),</span></pre></div>
</div>
</div>
<div class='section' id='section-85'>
<div class='docs'>
<div class='section-link'>
<a href='#section-85'>#</a>
</div>
<p>ワーカープロセスの数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">365</span> <span class="s1">&#39;n_workers&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-86'>
<div class='docs'>
<div class='section-link'>
<a href='#section-86'>#</a>
</div>
<p>1 回の更新で各プロセスで実行するステップの数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">367</span> <span class="s1">&#39;worker_steps&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-87'>
<div class='docs'>
<div class='section-link'>
<a href='#section-87'>#</a>
</div>
<p>ミニバッチ数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">369</span> <span class="s1">&#39;batches&#39;</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-88'>
<div class='docs'>
<div class='section-link'>
<a href='#section-88'>#</a>
</div>
<p>⚙️ 価値損失係数。これは実験の実行中に変更できます。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">372</span> <span class="s1">&#39;value_loss_coef&#39;</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">(</span><span class="mf">0.5</span><span class="p">),</span></pre></div>
</div>
</div>
<div class='section' id='section-89'>
<div class='docs'>
<div class='section-link'>
<a href='#section-89'>#</a>
</div>
<p>⚙️ エントロピーボーナス係数。これは実験の実行中に変更できます。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">375</span> <span class="s1">&#39;entropy_bonus_coef&#39;</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">(</span><span class="mf">0.01</span><span class="p">),</span></pre></div>
</div>
</div>
<div class='section' id='section-90'>
<div class='docs'>
<div class='section-link'>
<a href='#section-90'>#</a>
</div>
<p>⚙️ クリップレンジ。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">377</span> <span class="s1">&#39;clip_range&#39;</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">(</span><span class="mf">0.1</span><span class="p">),</span></pre></div>
</div>
</div>
<div class='section' id='section-91'>
<div class='docs'>
<div class='section-link'>
<a href='#section-91'>#</a>
</div>
<p>⚙️ 学習率。これは実験の実行中に変更できます。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">380</span> <span class="s1">&#39;learning_rate&#39;</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">(</span><span class="mf">1e-3</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1e-3</span><span class="p">)),</span>
<span class="lineno">381</span> <span class="p">}</span>
<span class="lineno">382</span>
<span class="lineno">383</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">configs</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-92'>
<div class='docs'>
<div class='section-link'>
<a href='#section-92'>#</a>
</div>
<p>トレーナーを初期化</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">386</span> <span class="n">m</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="o">**</span><span class="n">configs</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-93'>
<div class='docs'>
<div class='section-link'>
<a href='#section-93'>#</a>
</div>
<p>実験の実行と監視</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">389</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">390</span> <span class="n">m</span><span class="o">.</span><span class="n">run_training_loop</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-94'>
<div class='docs'>
<div class='section-link'>
<a href='#section-94'>#</a>
</div>
<p>ワーカーを停止する</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">392</span> <span class="n">m</span><span class="o">.</span><span class="n">destroy</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-95'>
<div class='docs'>
<div class='section-link'>
<a href='#section-95'>#</a>
</div>
<h2>実行する</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">396</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
<span class="lineno">397</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
/**
 * Enable click-to-zoom on every image that is a direct child of a paragraph
 * (selector `p>img`), delegating the per-image setup to handleImage.
 */
function handleImages() {
    var paragraphImages = document.querySelectorAll('p>img')
    // querySelectorAll returns a NodeList; borrow Array's forEach to walk it.
    Array.prototype.forEach.call(paragraphImages, function (image) {
        handleImage(image)
    })
}
/**
 * Make a single inline image zoomable.
 *
 * Centres the image inside its parent element and builds a detached overlay
 * (`#modal > div > img`, plus a `.close` control). Clicking the image attaches
 * the overlay to <body> showing the same source. The overlay is removed again
 * by clicking the close control, activating it with Enter/Space, or pressing
 * Escape anywhere in the document.
 *
 * @param {HTMLImageElement} img image to enhance; must have a parent element.
 */
function handleImage(img) {
    img.parentElement.style.textAlign = 'center'

    var modal = document.createElement('div')
    modal.id = 'modal'
    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)
    var modalImage = document.createElement('img')
    // Reuse the inline image's alt text so the enlarged copy is not announced
    // as an unlabelled image.
    modalImage.alt = img.alt
    modalContent.appendChild(modalImage)

    // Kept as a <span> so the existing `.close` styling still applies; the
    // role, tabindex and key handler below make it operable like a button.
    var span = document.createElement('span')
    span.classList.add('close')
    span.textContent = 'x'
    span.setAttribute('role', 'button')
    span.setAttribute('tabindex', '0')
    span.setAttribute('aria-label', 'Close')
    modal.appendChild(span)

    // Detach the overlay and stop listening for Escape.
    function closeModal() {
        // The overlay may already be detached (e.g. Escape followed by a stray
        // click); removeChild would throw in that case.
        if (modal.parentNode) {
            modal.parentNode.removeChild(modal)
        }
        document.removeEventListener('keydown', onDocumentKeyDown)
    }

    // Document-level Escape handler, registered only while the overlay is open.
    function onDocumentKeyDown(event) {
        if (event.key === 'Escape') {
            closeModal()
        }
    }

    img.addEventListener('click', function () {
        document.body.appendChild(modal)
        modalImage.src = img.src
        document.addEventListener('keydown', onDocumentKeyDown)
    })
    span.addEventListener('click', closeModal)
    span.addEventListener('keydown', function (event) {
        if (event.key === 'Enter' || event.key === ' ') {
            // Prevent Space from also scrolling the page.
            event.preventDefault()
            closeModal()
        }
    })
}
// This script sits at the end of <body>, so every `p>img` has already been
// parsed; wire up the zoom behaviour immediately.
handleImages();
</script>
</body>
</html>