Varuna Jayasiri 2038b11d29 ja translation
2023-05-10 17:00:29 -04:00


<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad and RAdam optimizers."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Optimizers"/>
<meta name="twitter:description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad and RAdam optimizers."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/optimizers/index.html"/>
<meta property="og:title" content="Optimizers"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Optimizers"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad and RAdam optimizers."/>
<title>Optimizers</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/optimizers/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">optimizers</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/optimizers/__init__.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Optimizers</h1>
<h2>Optimizer Implementations</h2>
<ul><li><a href="adam.html">Adam Optimizer</a></li>
<li><a href="amsgrad.html">AMSGrad Optimizer</a></li>
<li><a href="adam_warmup.html">Adam Optimizer with warmup</a></li>
<li><a href="noam.html">Noam Optimizer</a></li>
<li><a href="radam.html">Rectified Adam Optimizer</a></li>
<li><a href="ada_belief.html">AdaBelief Optimizer</a></li></ul>
<p>This <a href="mnist_experiment.html">MNIST example</a> uses these optimizers.</p>
<h2>Generic Adaptive Optimizer Base Class and Weight Decay</h2>
<p>This file defines a common base class for <em>Adam</em> and its extensions. Because the shared logic lives in the base class, other optimizers can be implemented with minimal code.</p>
<p>It also defines a separate class for L2 weight decay, so that decay does not have to be implemented inside each optimizer and can be extended to other schemes, such as L1, without changing the optimizers.</p>
<p>Here are some concepts behind PyTorch optimizers:</p>
<h3>Parameter groups</h3>
<p>PyTorch optimizers group parameters into sets called groups. Each group can have its own hyper-parameters, such as the learning rate.</p>
<p>In most cases there is only one group. This is when you initialize the optimizer with:</p>
<pre class="highlight lang-python"><code><span></span><span class="n">Optimizer</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span></code></pre>
<p>You can define multiple parameter groups when initializing the optimizer:</p>
<pre class="highlight lang-python"><code><span></span><span class="n">Optimizer</span><span class="p">([{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="n">model1</span><span class="o">.</span><span class="n">parameters</span><span class="p">()},</span> <span class="p">{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="n">model2</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="s1">&#39;lr&#39;</span><span class="p">:</span> <span class="mi">2</span><span class="p">}])</span></code></pre>
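<p>Here we pass a list of groups. Each group is a dictionary with its parameters under the key 'params'. You can specify any hyper-parameters in it as well. Hyper-parameters that are not defined in a group fall back to the optimizer-level defaults.</p>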
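<p>As a minimal sketch of that fallback, the snippet below uses <code>torch.optim.SGD</code> as a stand-in optimizer. The two <code>nn.Linear</code> layers and the learning-rate values are illustrative assumptions, not part of this module.</p>

```python
import torch
from torch import nn

# Two small stand-in models (hypothetical, for illustration only)
model1 = nn.Linear(4, 2)
model2 = nn.Linear(2, 1)

# The first group sets no 'lr', so it falls back to the optimizer-level default
optimizer = torch.optim.SGD(
    [{'params': model1.parameters()},
     {'params': model2.parameters(), 'lr': 2.0}],
    lr=0.1,
)

# Collect the learning rate each group ended up with
group_lrs = [group['lr'] for group in optimizer.param_groups]
print(group_lrs)
```

<p>The first group reports the optimizer-level learning rate, while the second keeps the value set in its own dictionary.</p>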
<p>You can access, and even change, these groups and their hyper-parameters through <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span></code>. Most learning rate schedule implementations I have come across access this and change 'lr'.</p>
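<p>A hand-rolled schedule along those lines might look like the sketch below. The helper name <code>decay_lr</code> and the decay factor are hypothetical; only <code>param_groups</code> and the 'lr' key come from PyTorch.</p>

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def decay_lr(optimizer, factor):
    """Multiply the learning rate of every parameter group by `factor`."""
    for group in optimizer.param_groups:
        group['lr'] *= factor


# Halve the learning rate in place
decay_lr(optimizer, 0.5)
new_lr = optimizer.param_groups[0]['lr']
```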
<h3>States</h3>
<p>The optimizer maintains a state (a dictionary) for each parameter (a tensor) in the dictionary <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">state</span></code>. This is where the optimizer keeps things like exponential averages.</p>
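<p>To make this concrete, here is a small sketch that inspects the state of PyTorch's built-in <code>torch.optim.Adam</code> after one step. The toy parameter and loss are assumptions for illustration, and the key names are those used by the built-in Adam, not necessarily by the classes in this file.</p>

```python
import torch
from torch import nn

param = nn.Parameter(torch.ones(3))
optimizer = torch.optim.Adam([param], lr=0.01)

# The state dictionary is empty until the first step
state_size_before = len(optimizer.state)

loss = (param ** 2).sum()
loss.backward()
optimizer.step()

# After one step, the per-parameter statistics are stored under the tensor itself
state_keys = sorted(optimizer.state[param].keys())
print(state_keys)
```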
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Any</span>
<span class="lineno">63</span>
<span class="lineno">64</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">65</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">66</span><span class="kn">from</span> <span class="nn">torch.optim.optimizer</span> <span class="kn">import</span> <span class="n">Optimizer</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Base class for <em>Adam</em> and extensions</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span><span class="k">class</span> <span class="nc">GenericAdaptiveOptimizer</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<h3>Initialize</h3>
<ul><li><code class="highlight"><span></span><span class="n">params</span></code>
is the collection of parameters or the set of parameter groups.</li>
<li><code class="highlight"><span></span><span class="n">defaults</span></code>
is a dictionary of default hyper-parameters</li>
<li><code class="highlight"><span></span><span class="n">lr</span></code>
is the learning rate, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.0037em;">α</span></span></span></span></span></li>
<li><code class="highlight"><span></span><span class="n">betas</span></code>
is the tuple <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span></li>
<li><code class="highlight"><span></span><span class="n">eps</span></code>
is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal">ϵ</span></span></span></span></span></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">defaults</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">betas</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>Check the hyper-parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span>        <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">lr</span><span class="p">:</span>
<span class="lineno">87</span>            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid learning rate: </span><span class="si">{</span><span class="n">lr</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="lineno">88</span>        <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">eps</span><span class="p">:</span>
<span class="lineno">89</span>            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid epsilon value: </span><span class="si">{</span><span class="n">eps</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="lineno">90</span>        <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">betas</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;</span> <span class="mf">1.0</span><span class="p">:</span>
<span class="lineno">91</span>            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid beta parameter at index 0: </span><span class="si">{</span><span class="n">betas</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="lineno">92</span>        <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">betas</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;</span> <span class="mf">1.0</span><span class="p">:</span>
<span class="lineno">93</span>            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid beta parameter at index 1: </span><span class="si">{</span><span class="n">betas</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Add the hyper-parameters to the defaults</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">96</span> <span class="n">defaults</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">dict</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Initialize the PyTorch optimizer. This creates the parameter groups with the default hyper-parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">99</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">defaults</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<h3>Initialize state for a given parameter tensor</h3>
<p>This should be overridden with code that initializes <code class="highlight"><span></span><span class="n">state</span></code>
for the parameter <code class="highlight"><span></span><span class="n">param</span></code>.
<code class="highlight"><span></span><span class="n">group</span></code>
is the parameter group dictionary to which <code class="highlight"><span></span><span class="n">param</span></code>
belongs.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="k">def</span> <span class="nf">init_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">param</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">pass</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<h3>Take an optimizer step on a parameter tensor</h3>
<p>This should be overridden to take the optimization step on the <code class="highlight"><span></span><span class="n">param</span></code>
tensor <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span></span></span></span></span>, where <code class="highlight"><span></span><span class="n">grad</span></code>
is the gradient for that parameter, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>, <code class="highlight"><span></span><span class="n">state</span></code>
is the optimizer state dictionary for that parameter, and <code class="highlight"><span></span><span class="n">group</span></code>
is the parameter group dictionary that <code class="highlight"><span></span><span class="n">param</span></code>
belongs to.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="k">def</span> <span class="nf">step_param</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">grad</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="k">pass</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<h3>Optimizer step</h3>
<p>We have created a template method that does the common work every <em>Adam</em>-based optimizer needs.</p>
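<p>The template-method structure can be illustrated with a framework-free toy. This is only a sketch under stated assumptions: <code>TemplateOptimizer</code> and <code>MomentumSGD</code> are hypothetical names, and parameters are plain dictionaries rather than tensors. The base class owns the loop, and a subclass supplies only <code>init_state</code> and <code>step_param</code>.</p>

```python
class TemplateOptimizer:
    """Toy, framework-free stand-in for the template-method structure."""

    def __init__(self, params, lr):
        # `params` is a list of dicts like {'value': float, 'grad': float or None}
        self.param_groups = [{'params': params, 'lr': lr}]
        self.state = {}

    def init_state(self, state, group, param):
        raise NotImplementedError

    def step_param(self, state, group, grad, param):
        raise NotImplementedError

    def step(self):
        # Common bookkeeping shared by every subclass
        for group in self.param_groups:
            for param in group['params']:
                if param['grad'] is None:
                    continue
                state = self.state.setdefault(id(param), {})
                if len(state) == 0:
                    self.init_state(state, group, param)
                self.step_param(state, group, param['grad'], param)


class MomentumSGD(TemplateOptimizer):
    """Hypothetical subclass: only the per-parameter logic is written here."""

    def init_state(self, state, group, param):
        state['velocity'] = 0.0

    def step_param(self, state, group, grad, param):
        state['velocity'] = 0.9 * state['velocity'] + grad
        param['value'] -= group['lr'] * state['velocity']


params = [{'value': 1.0, 'grad': 0.5}, {'value': 2.0, 'grad': None}]
optimizer = MomentumSGD(params, lr=0.1)
optimizer.step()
```

<p>The parameter without a gradient is skipped, and state is created lazily on the first step, mirroring the flow described in the sections that follow.</p>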
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">122</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">closure</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Calculate the loss.</p>
<p>🤔 I'm not sure when you need this. I guess that if you define a function that calculates the loss, calls <code class="highlight"><span></span><span class="n">loss</span><span class="o">.</span><span class="n">backward</span></code>,
and returns the loss, you can pass it to <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span></code>
instead of calling it yourself. 🤷‍♂️</p>
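<p>For context, some built-in PyTorch optimizers, such as <code>torch.optim.LBFGS</code>, need to re-evaluate the loss several times per step, which is why <code>step</code> accepts a closure. The sketch below uses <code>torch.optim.SGD</code> only because its result is easy to check by hand; the toy parameter and loss are assumptions for illustration.</p>

```python
import torch
from torch import nn

param = nn.Parameter(torch.tensor([1.0]))
optimizer = torch.optim.SGD([param], lr=0.1)


def closure():
    # Clear old gradients, recompute the loss, and backpropagate
    optimizer.zero_grad()
    loss = (param ** 2).sum()
    loss.backward()
    return loss


# `step` calls the closure itself and returns the loss it computed
loss = optimizer.step(closure)
loss_value = loss.item()
param_value = param.item()
```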
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span>        <span class="n">loss</span> <span class="o">=</span> <span class="kc">None</span>
<span class="lineno">134</span>        <span class="k">if</span> <span class="n">closure</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">135</span>            <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">enable_grad</span><span class="p">():</span>
<span class="lineno">136</span>                <span class="n">loss</span> <span class="o">=</span> <span class="n">closure</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Iterate through the parameter groups</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Iterate through the parameters in the parameter group</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">141</span> <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;params&#39;</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Skip if the parameter has no gradient</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span>                <span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">144</span>                    <span class="k">continue</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Get the gradient tensor</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>We don't handle sparse gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span>                <span class="k">if</span> <span class="n">grad</span><span class="o">.</span><span class="n">is_sparse</span><span class="p">:</span>
<span class="lineno">149</span>                    <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;GenericAdaptiveOptimizer does not support sparse gradients,&#39;</span>
<span class="lineno">150</span>                                       <span class="s1">&#39; please consider SparseAdam instead&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Get the state for the parameter</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span> <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="n">param</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Initialize the state if it is uninitialized</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span>                <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">state</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">157</span>                    <span class="bp">self</span><span class="o">.</span><span class="n">init_state</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">param</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Take the optimization step on the parameter</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="bp">self</span><span class="o">.</span><span class="n">step_param</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">grad</span><span class="p">,</span> <span class="n">param</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Return the loss calculated by the closure</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="k">return</span> <span class="n">loss</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<h2>L2 Weight decay</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span><span class="k">class</span> <span class="nc">WeightDecay</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<h3>Initialize weight decay</h3>
<ul><li><code class="highlight"><span></span><span class="n">weight_decay</span></code>
is the decay coefficient</li>
<li><code class="highlight"><span></span><span class="n">weight_decouple</span></code>
is a flag indicating whether to add the weight decay to the gradient or to decay the parameter directly. If it is added to the gradient, it goes through the normal optimizer update.</li>
<li><code class="highlight"><span></span><span class="n">absolute</span></code>
is a flag indicating whether the weight decay coefficient is absolute. It applies only when the decay is performed directly on the parameter. If it is false, the actual decay is <code class="highlight"><span></span><span class="n">weight_decay</span></code>
* <code class="highlight"><span></span><span class="n">learning_rate</span></code>.
</li>
</ul>
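<p>The three modes can be compared on a single scalar with a small, framework-free sketch. The helper <code>decay_step</code> and the numbers are hypothetical; it mirrors the branching of the class below on plain floats rather than tensors.</p>

```python
def decay_step(param, grad, lr, weight_decay, weight_decouple=True, absolute=False):
    """Return (new_param, grad_for_optimizer) for one scalar parameter."""
    if weight_decouple:
        # Shrink the parameter directly; the gradient is left untouched
        factor = weight_decay if absolute else lr * weight_decay
        return param * (1.0 - factor), grad
    # Coupled L2 decay: fold the penalty into the gradient instead
    return param, grad + weight_decay * param


decoupled = decay_step(2.0, 0.5, lr=0.1, weight_decay=0.01)
absolute = decay_step(2.0, 0.5, lr=0.1, weight_decay=0.01, absolute=True)
coupled = decay_step(2.0, 0.5, lr=0.1, weight_decay=0.01, weight_decouple=False)
```

<p>With decoupled decay the parameter shrinks before the optimizer update and the gradient passes through unchanged; with coupled decay the parameter is untouched here and the decay term is carried by the gradient.</p>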
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">weight_decouple</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">absolute</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Check the hyper-parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span>        <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">weight_decay</span><span class="p">:</span>
<span class="lineno">185</span>            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid weight_decay value: </span><span class="si">{</span><span class="n">weight_decay</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="lineno">186</span>
<span class="lineno">187</span>        <span class="bp">self</span><span class="o">.</span><span class="n">absolute</span> <span class="o">=</span> <span class="n">absolute</span>
<span class="lineno">188</span>        <span class="bp">self</span><span class="o">.</span><span class="n">weight_decouple</span> <span class="o">=</span> <span class="n">weight_decouple</span>
<span class="lineno">189</span>        <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">=</span> <span class="n">weight_decay</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Return the defaults for parameter groups</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="k">def</span> <span class="nf">defaults</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="k">return</span> <span class="nb">dict</span><span class="p">(</span><span class="n">weight_decay</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<h3>Perform weight decay and return the gradient</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">,</span> <span class="n">grad</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>If we are doing the decay on the parameter directly</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decouple</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>If the weight decay coefficient is absolute</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span>            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">absolute</span><span class="p">:</span>
<span class="lineno">206</span>                <span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Otherwise, scale the decay by the learning rate</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">208</span>            <span class="k">else</span><span class="p">:</span>
<span class="lineno">209</span>                <span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">]</span> <span class="o">*</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Return the unmodified gradient</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">211</span>            <span class="k">return</span> <span class="n">grad</span>
<span class="lineno">212</span>        <span class="k">else</span><span class="p">:</span>
<span class="lineno">213</span>            <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Add the weight decay to the gradient and return the modified gradient</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">215</span>                <span class="k">return</span> <span class="n">grad</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">])</span>
<span class="lineno">216</span>            <span class="k">else</span><span class="p">:</span>
<span class="lineno">217</span>                <span class="k">return</span> <span class="n">grad</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>