
<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="Running a DQN experiment with Atari Breakout"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="DQN Experiment with Atari Breakout"/>
<meta name="twitter:description" content="Running a DQN experiment with Atari Breakout"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/rl/dqn/experiment.html"/>
<meta property="og:title" content="DQN Experiment with Atari Breakout"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="DQN Experiment with Atari Breakout"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="DQN Experiment with Atari Breakout"/>
<meta property="og:description" content="Running a DQN experiment with Atari Breakout"/>
<title>DQN Experiment with Atari Breakout</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/rl/dqn/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/rl/dqn/experiment.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>DQN Experiment with Atari Breakout</h1>
<p>This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game on OpenAI Gym. It runs the <a href="../game.html">game environments on multiple processes</a> to sample efficiently.</p>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/dqn/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="lineno">16</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">17</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">logger</span><span class="p">,</span> <span class="n">monit</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml.internal.configs.dynamic_hyperparam</span> <span class="kn">import</span> <span class="n">FloatDynamicHyperParam</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_helpers.schedule</span> <span class="kn">import</span> <span class="n">Piecewise</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_nn.rl.dqn</span> <span class="kn">import</span> <span class="n">QFuncLoss</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.rl.dqn.model</span> <span class="kn">import</span> <span class="n">Model</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_nn.rl.dqn.replay_buffer</span> <span class="kn">import</span> <span class="n">ReplayBuffer</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.rl.game</span> <span class="kn">import</span> <span class="n">Worker</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p>Select the device</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span><span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span>
<span class="lineno">28</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cuda:0&quot;</span><span class="p">)</span>
<span class="lineno">29</span><span class="k">else</span><span class="p">:</span>
<span class="lineno">30</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>Scale observations from <code class="highlight"><span></span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">]</span></code>
to <code class="highlight"><span></span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span><span class="k">def</span> <span class="nf">obs_to_torch</span><span class="p">(</span><span class="n">obs</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">35</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">obs</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<h2>Trainer</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">38</span><span class="k">class</span> <span class="nc">Trainer</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">43</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">44</span> <span class="n">updates</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">epochs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">45</span> <span class="n">n_workers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">worker_steps</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">mini_batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">46</span> <span class="n">update_target_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">47</span> <span class="n">learning_rate</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">,</span>
<span class="lineno">48</span> <span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Number of workers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">50</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span> <span class="o">=</span> <span class="n">n_workers</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Number of steps sampled on each update</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">52</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span> <span class="o">=</span> <span class="n">worker_steps</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Number of training iterations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_epochs</span> <span class="o">=</span> <span class="n">epochs</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Number of updates</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="bp">self</span><span class="o">.</span><span class="n">updates</span> <span class="o">=</span> <span class="n">updates</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Size of the mini-batch used for training</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="bp">self</span><span class="o">.</span><span class="n">mini_batch_size</span> <span class="o">=</span> <span class="n">mini_batch_size</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Interval, in updates, at which the target network is updated (every 250 updates in this experiment)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="bp">self</span><span class="o">.</span><span class="n">update_target_model</span> <span class="o">=</span> <span class="n">update_target_model</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Learning rate</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="n">learning_rate</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Exploration coefficient as a function of the number of updates</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_coefficient</span> <span class="o">=</span> <span class="n">Piecewise</span><span class="p">(</span>
<span class="lineno">69</span> <span class="p">[</span>
<span class="lineno">70</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">),</span>
<span class="lineno">71</span> <span class="p">(</span><span class="mi">25_000</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">),</span>
<span class="lineno">72</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">updates</span> <span class="o">/</span> <span class="mi">2</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">)</span>
<span class="lineno">73</span> <span class="p">],</span> <span class="n">outside_value</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span></pre></div>
</div>
</div>
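<div class='section'>
<div class='docs'>
<p>The exploration schedule above can be read as interpolation between <code>(update, value)</code> endpoints. The following is a minimal pure-Python sketch of that idea, assuming linear interpolation between endpoints; it is not the <code>labml_helpers.schedule.Piecewise</code> implementation. The function name <code>piecewise_linear</code> and the total of 1,000,000 updates are hypothetical and chosen only for illustration.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def piecewise_linear(endpoints, x, outside_value):
    """Linearly interpolate between sorted (position, value) endpoints."""
    for (x0, y0), (x1, y1) in zip(endpoints[:-1], endpoints[1:]):
        # x falls inside the half-open segment [x0, x1)
        if x >= x0 and x1 > x:
            fraction = (x - x0) / (x1 - x0)
            return y0 + fraction * (y1 - y0)
    # x lies outside every segment
    return outside_value


# Hypothetical total number of updates, used only for this example
updates = 1_000_000
endpoints = [(0, 1.0), (25_000, 0.1), (updates / 2, 0.01)]

epsilon_start = piecewise_linear(endpoints, 0, 0.01)       # first endpoint
epsilon_mid = piecewise_linear(endpoints, 12_500, 0.01)    # halfway through the first segment
epsilon_late = piecewise_linear(endpoints, 750_000, 0.01)  # past the last endpoint
```
</pre></div>
</div>
</div>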
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" style="margin-right:0.05278em">β</span></span></span></span></span></span> for the prioritized replay buffer as a function of the number of updates</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="bp">self</span><span class="o">.</span><span class="n">prioritized_replay_beta</span> <span class="o">=</span> <span class="n">Piecewise</span><span class="p">(</span>
<span class="lineno">77</span> <span class="p">[</span>
<span class="lineno">78</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">),</span>
<span class="lineno">79</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">updates</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">80</span> <span class="p">],</span> <span class="n">outside_value</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Replay buffer with <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.0037em;">α</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">0.6</span></span></span></span></span>. The capacity of the replay buffer must be a power of 2.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span> <span class="bp">self</span><span class="o">.</span><span class="n">replay_buffer</span> <span class="o">=</span> <span class="n">ReplayBuffer</span><span class="p">(</span><span class="mi">2</span> <span class="o">**</span> <span class="mi">14</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Model used for sampling and training</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Target model used to compute <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.088326em;vertical-align:-0.276864em;"></span><span class="mord"><span class="mord mathnormal" style="color:orange">Q</span><span class="mopen coloredeq eqa" style="">(</span><span class="mord coloredeq eqa" style=""><span class="mord mathnormal" style="">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.751892em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">′</span></span></span></span></span></span></span></span></span><span class="mpunct coloredeq eqa" style="">;</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style="color:orange"><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.811462em;"><span style="top:-2.4231360000000004em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span><span style="top:-3.1031310000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">−</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.276864em;"><span></span></span></span></span></span></span><span class="mclose coloredeq eqa" style="">)</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Create the workers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="bp">self</span><span class="o">.</span><span class="n">workers</span> <span class="o">=</span> <span class="p">[</span><span class="n">Worker</span><span class="p">(</span><span class="mi">47</span> <span class="o">+</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Initialize the array that holds the observations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">94</span> <span class="bp">self</span><span class="o">.</span><span class="n">obs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">84</span><span class="p">,</span> <span class="mi">84</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Reset the workers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span> <span class="k">for</span> <span class="n">worker</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">:</span>
<span class="lineno">98</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">send</span><span class="p">((</span><span class="s2">&quot;reset&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Get the initial observations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">worker</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">):</span>
<span class="lineno">102</span> <span class="bp">self</span><span class="o">.</span><span class="n">obs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">recv</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Loss function</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span> <span class="o">=</span> <span class="n">QFuncLoss</span><span class="p">(</span><span class="mf">0.99</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Optimizer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">2.5e-4</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<h4><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqg" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span></span></span>-greedy Sampling</h4>
<p>When sampling actions we use an <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqg" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span></span></span>-greedy strategy: we take the greedy action with probability <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqg" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span></span></span> and a random action with probability <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqg" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span></span></span>. We refer to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqg" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span></span></span> as <code class="highlight"><span></span><span class="n">exploration_coefficient</span></code>.
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">109</span> <span class="k">def</span> <span class="nf">_sample_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q_value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">exploration_coefficient</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Sampling doesn&#x27;t need gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Pick the action with the highest Q-value. This is the greedy action.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="n">greedy_action</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">q_value</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Sample an action uniformly at random</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">123</span> <span class="n">random_action</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="n">q_value</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">greedy_action</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">q_value</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Decide whether to take the greedy action or the random action</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">125</span> <span class="n">is_choose_rand</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">greedy_action</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">q_value</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">exploration_coefficient</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Pick the action based on <code class="highlight"><span></span><span class="n">is_choose_rand</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">is_choose_rand</span><span class="p">,</span> <span class="n">random_action</span><span class="p">,</span> <span class="n">greedy_action</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span></pre></div>
</div>
</div>
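<div class='section'>
<div class='docs'>
<p>To make the two extremes of the exploration coefficient concrete, here is a small scalar sketch in plain Python rather than batched PyTorch. It is only an illustration of the same rule, not part of the experiment; the helper name <code>epsilon_greedy</code>, the example Q-values, and the seed are hypothetical. With a coefficient of 0 every choice is the greedy action, and with a coefficient of 1 every choice is uniformly random.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import random


def epsilon_greedy(q_values, epsilon, rng):
    """Return a random action index with probability epsilon, else the argmax."""
    # rng.random() is in [0, 1), so epsilon=0 never explores and epsilon=1 always does
    if epsilon > rng.random():
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


rng = random.Random(0)             # fixed seed so the example is repeatable
q_values = [0.1, 0.7, 0.2, 0.0]    # made-up Q-values for four actions

greedy_only = [epsilon_greedy(q_values, 0.0, rng) for _ in range(100)]
explored = [epsilon_greedy(q_values, 1.0, rng) for _ in range(1000)]
```
</pre></div>
</div>
</div>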
<div class='section' id='section-30'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<h3>Sample data</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">exploration_coefficient</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>This doesn&#x27;t need gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Sample <code class="highlight"><span></span><span class="n">worker_steps</span></code> steps
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">135</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Get the Q-values for the current observations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="n">q_value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">obs_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">obs</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Sample actions</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="n">actions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sample_action</span><span class="p">(</span><span class="n">q_value</span><span class="p">,</span> <span class="n">exploration_coefficient</span><span class="p">)</span></pre></div>
</div>
</div>
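<div class='section'>
<div class='docs'>
<p>The action-sampling step is commonly an epsilon-greedy rule: take the action with the highest Q-value, but with probability equal to the exploration coefficient pick a random action instead. The sketch below assumes that behaviour; <code>sample_actions</code> and the Q-values are hypothetical and use plain Python lists rather than tensors, so this is not the page's own <code>_sample_action</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import random


def sample_actions(q_values, epsilon, rng):
    """Pick one action per worker with an epsilon-greedy rule (illustrative helper)."""
    actions = []
    for q_row in q_values:
        # Greedy choice: index of the largest Q-value for this worker
        greedy = max(range(len(q_row)), key=lambda a: q_row[a])
        # With probability epsilon, replace it with a uniformly random action
        explore = rng.choices([True, False], weights=[epsilon, 1.0 - epsilon])[0]
        actions.append(rng.randrange(len(q_row)) if explore else greedy)
    return actions


rng = random.Random(0)
q_values = [[0.1, 0.9, 0.3], [0.5, 0.2, 0.4]]  # made-up Q-values: 2 workers, 3 actions
greedy_actions = sample_actions(q_values, 0.0, rng)  # epsilon = 0 is purely greedy
```
</pre></div>
</div>
</div>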
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Run the sampled actions on each worker</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="k">for</span> <span class="n">w</span><span class="p">,</span> <span class="n">worker</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">):</span>
<span class="lineno">143</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">send</span><span class="p">((</span><span class="s2">&quot;step&quot;</span><span class="p">,</span> <span class="n">actions</span><span class="p">[</span><span class="n">w</span><span class="p">]))</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Collect information from each worker</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="k">for</span> <span class="n">w</span><span class="p">,</span> <span class="n">worker</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Get the results after executing the actions</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span> <span class="n">next_obs</span><span class="p">,</span> <span class="n">reward</span><span class="p">,</span> <span class="n">done</span><span class="p">,</span> <span class="n">info</span> <span class="o">=</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">recv</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Add the transition to the replay buffer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">151</span> <span class="bp">self</span><span class="o">.</span><span class="n">replay_buffer</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">obs</span><span class="p">[</span><span class="n">w</span><span class="p">],</span> <span class="n">actions</span><span class="p">[</span><span class="n">w</span><span class="p">],</span> <span class="n">reward</span><span class="p">,</span> <span class="n">next_obs</span><span class="p">,</span> <span class="n">done</span><span class="p">)</span></pre></div>
</div>
</div>
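<div class='section'>
<div class='docs'>
<p>As a rough picture of what a fixed-capacity buffer does when transitions are added, here is a minimal ring buffer. It is an assumption-laden sketch: <code>RingBuffer</code> is hypothetical, stores no priorities, and only mirrors the <code>add</code> and <code>is_full</code> calls used on this page.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
class RingBuffer:
    """Minimal fixed-capacity transition store (no priorities), for illustration."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.data = [None] * capacity
        self.next_idx = 0
        self.size = 0

    def add(self, obs, action, reward, next_obs, done):
        # Overwrite the oldest slot once the buffer has wrapped around
        self.data[self.next_idx] = (obs, action, reward, next_obs, done)
        self.next_idx = (self.next_idx + 1) % self.capacity
        self.size = min(self.capacity, self.size + 1)

    def is_full(self):
        return self.size == self.capacity


buffer = RingBuffer(capacity=3)
for step in range(4):  # one more add than the capacity, so slot 0 is overwritten
    buffer.add(obs=step, action=0, reward=1.0, next_obs=step + 1, done=False)
```
</pre></div>
</div>
</div>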
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Update the episode information. Episode info is available only when an episode has finished; it includes the total reward and the length of the episode. See <code class="highlight"><span></span><span class="n">Game</span></code> for how it works.
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span> <span class="k">if</span> <span class="n">info</span><span class="p">:</span>
<span class="lineno">158</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;reward&#39;</span><span class="p">,</span> <span class="n">info</span><span class="p">[</span><span class="s1">&#39;reward&#39;</span><span class="p">])</span>
<span class="lineno">159</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;length&#39;</span><span class="p">,</span> <span class="n">info</span><span class="p">[</span><span class="s1">&#39;length&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Update the current observation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="bp">self</span><span class="o">.</span><span class="n">obs</span><span class="p">[</span><span class="n">w</span><span class="p">]</span> <span class="o">=</span> <span class="n">next_obs</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<h3>Train the model</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_epochs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Sample from the prioritized replay buffer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">170</span> <span class="n">samples</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">replay_buffer</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mini_batch_size</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Get the predicted Q-values</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">172</span> <span class="n">q_value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">obs_to_torch</span><span class="p">(</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;obs&#39;</span><span class="p">]))</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Get the Q-values of the next state for <a href="index.html">Double Q-learning</a>. Gradients should not propagate through these.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.001892em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="color:cyan">Q</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.751892em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mpunct">;</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style="color:cyan"><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">178</span> <span class="n">double_q_value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">obs_to_torch</span><span class="p">(</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;next_obs&#39;</span><span class="p">]))</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.088326em;vertical-align:-0.276864em;"></span><span class="mord"><span class="mord mathnormal" style="color:orange">Q</span><span class="mopen coloredeq eqa" style="">(</span><span class="mord coloredeq eqa" style=""><span class="mord mathnormal" style="">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.751892em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">′</span></span></span></span></span></span></span></span></span><span class="mpunct coloredeq eqa" style="">;</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style="color:orange"><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.811462em;"><span style="top:-2.4231360000000004em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span><span style="top:-3.1031310000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">−</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.276864em;"><span></span></span></span></span></span></span><span class="mclose coloredeq eqa" style="">)</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">180</span> <span class="n">target_q_value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_model</span><span class="p">(</span><span class="n">obs_to_torch</span><span class="p">(</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;next_obs&#39;</span><span class="p">]))</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>Compute the temporal difference (TD) errors, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqh" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span></span></span></span></span></span>, and the loss, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal">L</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">)</span></span></span></span></span>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">183</span> <span class="n">td_errors</span><span class="p">,</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">q_value</span><span class="p">,</span>
<span class="lineno">184</span> <span class="n">q_value</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">(</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;action&#39;</span><span class="p">]),</span>
<span class="lineno">185</span> <span class="n">double_q_value</span><span class="p">,</span> <span class="n">target_q_value</span><span class="p">,</span>
<span class="lineno">186</span> <span class="n">q_value</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">(</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;done&#39;</span><span class="p">]),</span>
<span class="lineno">187</span> <span class="n">q_value</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">(</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;reward&#39;</span><span class="p">]),</span>
<span class="lineno">188</span> <span class="n">q_value</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">(</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;weights&#39;</span><span class="p">]))</span></pre></div>
</div>
</div>
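<div class='section'>
<div class='docs'>
<p>To make the roles of the three Q-value tensors concrete, the sketch below computes Double Q-learning TD errors on plain lists: the online network picks the best next action and the target network evaluates it. The function name, the discount factor 0.5 and the sign convention (predicted value minus target) are assumptions for illustration; the page's own <code>loss_func</code> may differ in detail.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def double_q_td_errors(q, actions, double_q, target_q, dones, rewards, gamma):
    """TD errors for Double Q-learning on plain Python lists (illustrative only)."""
    td_errors = []
    for i in range(len(q)):
        # The online network chooses the best action in the next state ...
        best_next = max(range(len(double_q[i])), key=lambda a: double_q[i][a])
        # ... and the target network supplies the value of that action
        next_value = target_q[i][best_next]
        # Terminal transitions (done = 1) do not bootstrap from the next state
        target = rewards[i] + gamma * next_value * (1.0 - dones[i])
        td_errors.append(q[i][actions[i]] - target)
    return td_errors


td_errors = double_q_td_errors(
    q=[[1.0, 2.0], [0.0, 1.0]],          # Q(s, a) from the online network
    actions=[1, 0],                      # actions that were actually taken
    double_q=[[0.5, 3.0], [2.0, 1.0]],   # online network on the next state
    target_q=[[4.0, 1.0], [4.0, 0.0]],   # target network on the next state
    dones=[0.0, 1.0],
    rewards=[1.0, 2.0],
    gamma=0.5,
)
```
</pre></div>
</div>
</div>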
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Calculate the priorities for the replay buffer, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord">∣</span><span class="mord"><span class="mord coloredeq eqh" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqg" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="n">new_priorities</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">td_errors</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> <span class="o">+</span> <span class="mf">1e-6</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Update the replay buffer priorities</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="bp">self</span><span class="o">.</span><span class="n">replay_buffer</span><span class="o">.</span><span class="n">update_priorities</span><span class="p">(</span><span class="n">samples</span><span class="p">[</span><span class="s1">&#39;indexes&#39;</span><span class="p">],</span> <span class="n">new_priorities</span><span class="p">)</span></pre></div>
</div>
</div>
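<div class='section'>
<div class='docs'>
<p>The sketch below shows how priorities of this form are typically used in prioritized replay: they are raised to a power alpha to get sampling probabilities, and importance-sampling weights controlled by beta correct the resulting bias. The helper names and the value alpha = 0.6 are assumptions for illustration, not taken from this page.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def priorities_from_td(td_errors, eps=1e-6):
    # p_i = |delta_i| + eps, so a zero TD error still leaves a small chance of replay
    return [abs(d) + eps for d in td_errors]


def sampling_probabilities(priorities, alpha=0.6):
    # P(i) is proportional to p_i ** alpha; alpha = 0 would give uniform sampling
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def importance_weights(probs, beta):
    # w_i = (N * P(i)) ** (-beta), normalised so the largest weight is 1
    n = len(probs)
    raw = [(n * p) ** (-beta) for p in probs]
    top = max(raw)
    return [w / top for w in raw]


priorities = priorities_from_td([0.5, -2.0, 0.0])
probs = sampling_probabilities(priorities)
weights = importance_weights(probs, beta=0.4)
```
</pre></div>
</div>
</div>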
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>Set the learning rate</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="k">for</span> <span class="n">pg</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
<span class="lineno">197</span> <span class="n">pg</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">()</span></pre></div>
</div>
</div>
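<div class='section'>
<div class='docs'>
<p>Calling <code>self.learning_rate()</code> on every step suggests a learning rate that can change while training runs. A minimal sketch of that idea follows; <code>DynamicFloat</code> is a hypothetical stand-in written for this example (including the clamping to a range) and is not the library's actual hyperparameter class.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
class DynamicFloat:
    """Hypothetical hyperparameter whose value can be changed during training."""

    def __init__(self, default, value_range):
        self.low, self.high = value_range
        self.value = default

    def set(self, value):
        # Keep the value inside the allowed range
        self.value = min(max(value, self.low), self.high)

    def __call__(self):
        return self.value


learning_rate = DynamicFloat(1e-4, (0.0, 1e-3))
param_groups = [{'lr': 0.0}, {'lr': 0.0}]  # mimics the shape of optimizer.param_groups
for pg in param_groups:
    pg['lr'] = learning_rate()  # re-read the current value before each step
learning_rate.set(5e-3)  # an out-of-range request is clamped to the upper bound
```
</pre></div>
</div>
</div>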
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>Zero out the previously calculated gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Calculate the gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>Clip the gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">max_norm</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span></pre></div>
</div>
</div>
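<div class='section'>
<div class='docs'>
<p>Clipping by norm rescales all gradients together so that their combined L2 norm does not exceed <code>max_norm</code>, and leaves small gradients untouched. The sketch below illustrates this on a flat list of numbers; <code>clip_by_global_norm</code> is a hypothetical helper, and the small constant in the denominator is an assumption to avoid division by zero.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale a flat list of gradient values so their L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # The scale factor is capped at 1, so small gradients pass through unchanged
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=0.5)
norm_after = math.sqrt(sum(g * g for g in clipped))
```
</pre></div>
</div>
</div>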
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>Update the parameters based on the gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<h3>Run training loop</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">207</span> <span class="k">def</span> <span class="nf">run_training_loop</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Information from the last 100 episodes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">213</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_queue</span><span class="p">(</span><span class="s1">&#39;reward&#39;</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">214</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_queue</span><span class="p">(</span><span class="s1">&#39;length&#39;</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
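<div class='section'>
<div class='docs'>
<p>Assuming these queues keep a sliding window of the most recent 100 values, the same idea can be sketched with a bounded <code>deque</code> from the standard library. The rewards here are made up, and this is only an analogy for the tracker, not its implementation.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
from collections import deque

recent_rewards = deque(maxlen=100)  # the oldest entries drop off automatically
for episode_reward in range(250):   # made-up episode rewards 0, 1, ..., 249
    recent_rewards.append(float(episode_reward))

# Mean over the window, i.e. over the last 100 episodes only
mean_recent = sum(recent_rewards) / len(recent_rewards)
```
</pre></div>
</div>
</div>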
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Copy the model weights to the target network initially</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">())</span>
<span class="lineno">218</span>
<span class="lineno">219</span> <span class="k">for</span> <span class="n">update</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">updates</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqg" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span></span></span>, the exploration fraction</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">221</span> <span class="n">exploration</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">exploration_coefficient</span><span class="p">(</span><span class="n">update</span><span class="p">)</span>
<span class="lineno">222</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;exploration&#39;</span><span class="p">,</span> <span class="n">exploration</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" style="margin-right:0.05278em">β</span></span></span></span></span></span> for prioritized replay</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">224</span> <span class="n">beta</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prioritized_replay_beta</span><span class="p">(</span><span class="n">update</span><span class="p">)</span>
<span class="lineno">225</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;beta&#39;</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span></pre></div>
</div>
</div>
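<div class='section'>
<div class='docs'>
<p>Both coefficients are functions of the update index. One common way to build such schedules is piecewise-linear interpolation between a few breakpoints, sketched below. The helper <code>piecewise_linear</code> and every breakpoint value are assumptions chosen for illustration; they are not read from this page.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import bisect


def piecewise_linear(points, x):
    """Interpolate between sorted (x, y) points; clamp to the end values outside."""
    xs = [px for px, _ in points]
    ys = [py for _, py in points]
    # Clamp x into the covered range so values outside reuse the end points
    x = min(max(x, xs[0]), xs[-1])
    # Index of the right-hand end of the segment that contains x
    i = min(bisect.bisect_right(xs, x), len(xs) - 1)
    x0, x1 = xs[i - 1], xs[i]
    y0, y1 = ys[i - 1], ys[i]
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


# Hypothetical schedules: exploration decays, beta anneals towards 1
exploration_points = [(0, 1.0), (25_000, 0.1), (500_000, 0.01)]
beta_points = [(0, 0.4), (1_000_000, 1.0)]

exploration_start = piecewise_linear(exploration_points, 0)
exploration_mid = piecewise_linear(exploration_points, 12_500)
exploration_late = piecewise_linear(exploration_points, 2_000_000)
beta_half = piecewise_linear(beta_points, 500_000)
```
</pre></div>
</div>
</div>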
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p>Sample with the current policy</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">228</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">exploration</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<p>Start training once the buffer is full</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">231</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">replay_buffer</span><span class="o">.</span><span class="n">is_full</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>Train the model</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">233</span> <span class="bp">self</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">beta</span><span class="p">)</span></pre></div>
</div>
</div>
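<div class='section'>
<div class='docs'>
<p>Because training waits for a full buffer, the first updates only collect data. The toy simulation below counts how many updates actually train. The capacity of 3,200 is an arbitrary assumption, and 32 transitions per update corresponds to 8 workers running 4 steps each; no model or environment is involved.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def count_training_updates(updates, capacity, transitions_per_update):
    """Count the updates that train, given a buffer that must fill up first."""
    stored = 0
    train_calls = 0
    for _ in range(updates):
        # Sampling adds a fixed number of transitions; the buffer is capped at capacity
        stored = min(capacity, stored + transitions_per_update)
        if stored == capacity:
            train_calls += 1
    return train_calls


train_calls = count_training_updates(updates=1_000, capacity=3_200, transitions_per_update=32)
warmup_updates = 1_000 - train_calls  # updates spent only filling the buffer
```
</pre></div>
</div>
</div>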
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<p>Periodically update the target network</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">236</span> <span class="k">if</span> <span class="n">update</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">update_target_model</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">237</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<p>Save tracked indicators.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">240</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<p>Periodically add a new line to the screen</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">242</span> <span class="k">if</span> <span class="p">(</span><span class="n">update</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">1_000</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">243</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
<h3>Destroy</h3>
<p>Stop the workers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">245</span> <span class="k">def</span> <span class="nf">destroy</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">250</span> <span class="k">for</span> <span class="n">worker</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">workers</span><span class="p">:</span>
<span class="lineno">251</span> <span class="n">worker</span><span class="o">.</span><span class="n">child</span><span class="o">.</span><span class="n">send</span><span class="p">((</span><span class="s2">&quot;close&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">254</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-70'>
<div class='docs'>
<div class='section-link'>
<a href='#section-70'>#</a>
</div>
<p>Create the experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">256</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;dqn&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-71'>
<div class='docs'>
<div class='section-link'>
<a href='#section-71'>#</a>
</div>
<p>Configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">259</span> <span class="n">configs</span> <span class="o">=</span> <span class="p">{</span></pre></div>
</div>
</div>
<div class='section' id='section-72'>
<div class='docs'>
<div class='section-link'>
<a href='#section-72'>#</a>
</div>
<p>Number of updates</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">261</span> <span class="s1">&#39;updates&#39;</span><span class="p">:</span> <span class="mi">1_000_000</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-73'>
<div class='docs'>
<div class='section-link'>
<a href='#section-73'>#</a>
</div>
<p>Number of epochs to train the model with sampled data.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">263</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-74'>
<div class='docs'>
<div class='section-link'>
<a href='#section-74'>#</a>
</div>
<p>Number of worker processes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">265</span> <span class="s1">&#39;n_workers&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-75'>
<div class='docs'>
<div class='section-link'>
<a href='#section-75'>#</a>
</div>
<p>Number of steps to run on each process for a single update</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">267</span> <span class="s1">&#39;worker_steps&#39;</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-76'>
<div class='docs'>
<div class='section-link'>
<a href='#section-76'>#</a>
</div>
<p>Mini-batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">269</span> <span class="s1">&#39;mini_batch_size&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-77'>
<div class='docs'>
<div class='section-link'>
<a href='#section-77'>#</a>
</div>
<p>Target model update interval, in updates</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">271</span> <span class="s1">&#39;update_target_model&#39;</span><span class="p">:</span> <span class="mi">250</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-78'>
<div class='docs'>
<div class='section-link'>
<a href='#section-78'>#</a>
</div>
<p>Learning rate.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">273</span> <span class="s1">&#39;learning_rate&#39;</span><span class="p">:</span> <span class="n">FloatDynamicHyperParam</span><span class="p">(</span><span class="mf">1e-4</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1e-3</span><span class="p">)),</span>
<span class="lineno">274</span> <span class="p">}</span></pre></div>
</div>
</div>
<div class='section' id='section-79'>
<div class='docs'>
<div class='section-link'>
<a href='#section-79'>#</a>
</div>
<p>Register the configurations with the experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">277</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">configs</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-80'>
<div class='docs'>
<div class='section-link'>
<a href='#section-80'>#</a>
</div>
<p>Initialize the trainer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">280</span> <span class="n">m</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="o">**</span><span class="n">configs</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-81'>
<div class='docs'>
<div class='section-link'>
<a href='#section-81'>#</a>
</div>
<p>Run and monitor the experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">282</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">283</span> <span class="n">m</span><span class="o">.</span><span class="n">run_training_loop</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-82'>
<div class='docs'>
<div class='section-link'>
<a href='#section-82'>#</a>
</div>
<p>Stop the workers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">285</span> <span class="n">m</span><span class="o">.</span><span class="n">destroy</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-83'>
<div class='docs'>
<div class='section-link'>
<a href='#section-83'>#</a>
</div>
<h2>Run it</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">289</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
<span class="lineno">290</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
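<div class='section' id='section-84'>
<div class='docs'>
<div class='section-link'>
<a href='#section-84'>#</a>
</div>
<p>The configuration values in <code>main</code> above imply a few simple counts that are handy when sizing a run. The following is a minimal sketch, not part of the experiment itself: <code>summarize</code> is a hypothetical helper, and it assumes every worker contributes <code>worker_steps</code> samples per update and that the target model is copied once every <code>update_target_model</code> updates.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
# Values copied from the `configs` dictionary in `main`
# (the learning rate is left out because it is not needed here).
configs = {
    'updates': 1_000_000,
    'epochs': 8,
    'n_workers': 8,
    'worker_steps': 4,
    'mini_batch_size': 32,
    'update_target_model': 250,
}


def summarize(c):
    """Derive simple counts from the configuration.

    Hypothetical helper for illustration only; it is not a labml API.
    """
    # Assumption: each worker contributes `worker_steps` samples per update.
    samples_per_update = c['n_workers'] * c['worker_steps']
    return {
        'samples_per_update': samples_per_update,
        'total_samples': samples_per_update * c['updates'],
        # Assumption: the target model is copied once every
        # `update_target_model` updates.
        'target_model_copies': c['updates'] // c['update_target_model'],
    }


summary = summarize(configs)
print(summary)
```
</pre></div>
</div>
</div>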
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>