<!-- Varuna Jayasiri 1c14551a19 zh 2023-02-28 08:40:22 +05:30 -->
<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad and RAdam optimizers."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Optimizers"/>
<meta name="twitter:description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad and RAdam optimizers."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/optimizers/index.html"/>
<meta property="og:title" content="Optimizers"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Optimizers"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad and RAdam optimizers."/>
<title>Optimizers</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/optimizers/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">optimizers</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/optimizers/__init__.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
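<h1>Optimizers</h1>
<h2>Optimizer Implementations</h2>
<ul><li><a href="adam.html">Adam Optimizer</a></li>
<li><a href="amsgrad.html">AMSGrad Optimizer</a></li>
<li><a href="adam_warmup.html">Adam Optimizer with warmup</a></li>
<li><a href="noam.html">Noam Optimizer</a></li>
<li><a href="radam.html">Rectified Adam Optimizer</a></li>
<li><a href="ada_belief.html">AdaBelief Optimizer</a></li></ul>
<p>This <a href="mnist_experiment.html">MNIST example</a> uses these optimizers.</p>
<h2>Generic Adaptive Optimizer Base Class and Weight Decay</h2>
<p>This file defines a common base class for <em>Adam</em> and its extensions. Because the shared logic is reusable, the base class lets us implement other optimizers with minimal code.</p>
<p>We also define a special class for L2 weight decay, so that we don't have to implement it inside each optimizer and can easily extend to other weight decays, such as L1, without changing the optimizers.</p>
<p>Here are some concepts on PyTorch optimizers:</p>
<h3>Parameter groups</h3>
<p>PyTorch optimizers group parameters into sets called groups. Each group can have its own hyper-parameters, such as the learning rate.</p>
<p>In most cases there is only one group. This is what you get when you initialize your optimizer with:</p>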
<pre class="highlight lang-python"><code><span></span><span class="n">Optimizer</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span></code></pre>
<p>You can define multiple parameter groups when initializing the optimizer:</p>
<pre class="highlight lang-python"><code><span></span><span class="n">Optimizer</span><span class="p">([{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="n">model1</span><span class="o">.</span><span class="n">parameters</span><span class="p">()},</span> <span class="p">{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="n">model2</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="s1">&#39;lr&#39;</span><span class="p">:</span> <span class="mi">2</span><span class="p">}])</span></code></pre>
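<p>Here we pass a list of groups. Each group is a dictionary with its parameters under the key 'params'. You can specify any hyper-parameters as well. Hyper-parameters that are not defined in a group default to the optimizer-level defaults.</p>
<p>You can access (and even change) these groups and their hyper-parameters with <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span></code>. Most learning rate schedule implementations I've come across access this and change 'lr'.</p>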
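<p>As a rough illustration of the points above, the sketch below builds two groups and then rescales their learning rates by hand, roughly the way a simple schedule might. The two <code class="highlight"><span></span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span></code> models, the choice of <code class="highlight"><span></span><span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">SGD</span></code> and the learning rate values are assumptions made only for this example.</p>
<pre class="highlight lang-python"><code>import torch
from torch import nn

# Two small stand-in models (hypothetical, only for illustration)
model1 = nn.Linear(4, 2)
model2 = nn.Linear(2, 1)

# The first group falls back to the optimizer-level default `lr=0.1`;
# the second group overrides it with its own `lr`
optimizer = torch.optim.SGD(
    [{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 0.5}],
    lr=0.1,
)

# Read the per-group learning rates
lrs = [group['lr'] for group in optimizer.param_groups]

# Change them in place, as a hand-written schedule could
for group in optimizer.param_groups:
    group['lr'] = group['lr'] * 0.5</code></pre>
<p>Each entry of <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span></code> is a plain dictionary, so the change takes effect on the next call to <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span></code>.</p>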
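<h3>States</h3>
<p>The optimizer maintains states (a dictionary) for each parameter (a tensor) in the dictionary <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">state</span></code>. This is where the optimizer maintains things like exponential averages.</p>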
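<p>To make this concrete, here is a small sketch that inspects the state of PyTorch's built-in <code class="highlight"><span></span><span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span></code> before and after one step. The model and the input are assumptions for the example; the point is only that the state starts out empty and is filled per parameter on the first step.</p>
<pre class="highlight lang-python"><code>import torch
from torch import nn

# A tiny stand-in model (hypothetical, only for illustration)
model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# No state exists until the first optimization step
empty_before = len(optimizer.state) == 0

# One backward pass and one step
loss = model(torch.ones(2, 3)).sum()
loss.backward()
optimizer.step()

# The state is keyed by the parameter tensor itself
state_keys = sorted(optimizer.state[model.weight].keys())</code></pre>
<p>For Adam, the per-parameter dictionary should include the exponential averages <code class="highlight"><span></span><span class="n">exp_avg</span></code> and <code class="highlight"><span></span><span class="n">exp_avg_sq</span></code>, which are the running estimates of the gradient and the squared gradient.</p>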
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Any</span>
<span class="lineno">63</span>
<span class="lineno">64</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">65</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">66</span><span class="kn">from</span> <span class="nn">torch.optim.optimizer</span> <span class="kn">import</span> <span class="n">Optimizer</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Base class for <em>Adam</em> and extensions</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span><span class="k">class</span> <span class="nc">GenericAdaptiveOptimizer</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<h3>Initialize</h3>
<ul><li><code class="highlight"><span></span><span class="n">params</span></code>
 is the collection of parameters or a set of parameter groups.</li>
<li><code class="highlight"><span></span><span class="n">defaults</span></code>
 is a dictionary of default hyper-parameters</li>
<li><code class="highlight"><span></span><span class="n">lr</span></code>
 is the learning rate, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.0037em;">α</span></span></span></span></span></li>
<li><code class="highlight"><span></span><span class="n">betas</span></code>
 is the tuple <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span></li>
<li><code class="highlight"><span></span><span class="n">eps</span></code>
 is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal">ϵ</span></span></span></span></span></li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">defaults</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">betas</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>Check the hyper-parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">lr</span><span class="p">:</span>
<span class="lineno">87</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid learning rate: </span><span class="si">{</span><span class="n">lr</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="lineno">88</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">eps</span><span class="p">:</span>
<span class="lineno">89</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid epsilon value: </span><span class="si">{</span><span class="n">eps</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="lineno">90</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">betas</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;</span> <span class="mf">1.0</span><span class="p">:</span>
<span class="lineno">91</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid beta parameter at index 0: </span><span class="si">{</span><span class="n">betas</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="lineno">92</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">betas</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;</span> <span class="mf">1.0</span><span class="p">:</span>
<span class="lineno">93</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid beta parameter at index 1: </span><span class="si">{</span><span class="n">betas</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Add the hyper-parameters to the defaults</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">96</span> <span class="n">defaults</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">dict</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Initialize the PyTorch optimizer. This creates the parameter groups with the default hyper-parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">99</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">defaults</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<h3>Initialize state for a given parameter tensor</h3>
<p>This should be overridden with code to initialize <code class="highlight"><span></span><span class="n">state</span></code>
 for the parameter <code class="highlight"><span></span><span class="n">param</span></code>.
<code class="highlight"><span></span><span class="n">group</span></code>
 is the parameter group dictionary that <code class="highlight"><span></span><span class="n">param</span></code>
 belongs to.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="k">def</span> <span class="nf">init_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">param</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">pass</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<h3>Take an optimizer step on a parameter tensor</h3>
<p>This should be overridden to take the optimization step on the <code class="highlight"><span></span><span class="n">param</span></code>
 tensor <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span></span></span></span></span>, where <code class="highlight"><span></span><span class="n">grad</span></code>
 is the gradient for that parameter, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>, <code class="highlight"><span></span><span class="n">state</span></code>
 is the optimizer state dictionary for that parameter, and <code class="highlight"><span></span><span class="n">group</span></code>
 is the parameter group dictionary that <code class="highlight"><span></span><span class="n">param</span></code>
 belongs to.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="k">def</span> <span class="nf">step_param</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">grad</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="k">pass</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<h3>Optimizer step</h3>
<p>We have created a template method that does the common work every <em>Adam</em>-based optimizer needs.</p>
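<p>A minimal sketch of how a subclass might plug into such a template method is given here. <code class="highlight"><span></span><span class="n">TemplateOptimizer</span></code> is a cut-down stand-in for the base class on this page (no hyper-parameter checks, no sparse-gradient check), and <code class="highlight"><span></span><span class="n">MomentumSketch</span></code> with its fixed averaging factor of 0.9 is a hypothetical optimizer invented for this example; neither is part of the original code.</p>
<pre class="highlight lang-python"><code>import torch
from torch import nn
from torch.optim.optimizer import Optimizer


class TemplateOptimizer(Optimizer):
    """Cut-down stand-in for the base class, keeping only the template method."""

    def __init__(self, params, lr):
        super().__init__(params, dict(lr=lr))

    def init_state(self, state, group, param):
        pass

    def step_param(self, state, group, grad, param):
        pass

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for param in group['params']:
                # Skip parameters that received no gradient
                if param.grad is None:
                    continue
                state = self.state[param]
                # Lazily create the state on the first step
                if len(state) == 0:
                    self.init_state(state, group, param)
                self.step_param(state, group, param.grad, param)
        return loss


class MomentumSketch(TemplateOptimizer):
    """Hypothetical subclass: descend along an exponential average of gradients."""

    def init_state(self, state, group, param):
        state['step'] = 0
        state['exp_avg'] = torch.zeros_like(param)

    def step_param(self, state, group, grad, param):
        state['step'] += 1
        # m = 0.9 * m + 0.1 * g
        state['exp_avg'].mul_(0.9).add_(grad, alpha=0.1)
        # theta = theta - lr * m
        param.add_(state['exp_avg'], alpha=-group['lr'])


weight = nn.Parameter(torch.ones(2))
optimizer = MomentumSketch([weight], lr=1.0)
weight.sum().backward()  # the gradient is all ones
optimizer.step()</code></pre>
<p>The subclass only supplies the two hooks; the loops over groups and parameters, the gradient check and the lazy state creation all stay in the shared <code class="highlight"><span></span><span class="n">step</span></code> method.</p>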
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">122</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">closure</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Calculate the loss.</p>
<p>🤔 I'm not sure when you need this. I guess that if you define a function that calculates the loss, calls <code class="highlight"><span></span><span class="n">loss</span><span class="o">.</span><span class="n">backward</span></code>
 and returns the loss, you can pass it to <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span></code>
 instead of calling it yourself. 🤷‍♂️</p>
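<p>For what it's worth, here is a small sketch of passing such a closure to <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span></code>. The model, data and use of <code class="highlight"><span></span><span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">SGD</span></code> are assumptions for the example. As far as I know, the closure mainly matters for optimizers such as <code class="highlight"><span></span><span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">LBFGS</span></code> that re-evaluate the loss several times within one step.</p>
<pre class="highlight lang-python"><code>import torch
from torch import nn

# Stand-in model and data (hypothetical, only for illustration)
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.ones(4, 2)
targets = torch.zeros(4, 1)


def closure():
    # Clear old gradients, compute the loss, and back-propagate
    optimizer.zero_grad()
    loss = ((model(inputs) - targets) ** 2).mean()
    loss.backward()
    return loss


# The optimizer calls the closure and returns the loss it computed
loss = optimizer.step(closure)</code></pre>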
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="n">loss</span> <span class="o">=</span> <span class="kc">None</span>
<span class="lineno">134</span> <span class="k">if</span> <span class="n">closure</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">135</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">enable_grad</span><span class="p">():</span>
<span class="lineno">136</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">closure</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Loop through the parameter groups</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Loop through the parameters in the parameter group</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">141</span> <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;params&#39;</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Skip the parameter if it has no gradient</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span> <span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">144</span> <span class="k">continue</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Get the gradient tensor</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>We don't handle sparse gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span> <span class="k">if</span> <span class="n">grad</span><span class="o">.</span><span class="n">is_sparse</span><span class="p">:</span>
<span class="lineno">149</span> <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;GenericAdaptiveOptimizer does not support sparse gradients,&#39;</span>
<span class="lineno">150</span> <span class="s1">&#39; please consider SparseAdam instead&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Get the state for the parameter</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span> <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="n">param</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Initialize the state if it is uninitialized</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">state</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">157</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_state</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">param</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Take the optimization step on the parameter</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="bp">self</span><span class="o">.</span><span class="n">step_param</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">grad</span><span class="p">,</span> <span class="n">param</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Return the loss calculated from the closure</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="k">return</span> <span class="n">loss</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<h2>L2 Weight decay</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span><span class="k">class</span> <span class="nc">WeightDecay</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
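<h3>Initialize weight decay</h3>
<ul><li><code class="highlight"><span></span><span class="n">weight_decay</span></code>
 is the decay coefficient</li>
<li><code class="highlight"><span></span><span class="n">weight_decouple</span></code>
 is a flag indicating whether to add the weight decay to the gradient or to decay the parameter directly. If it is added to the gradient, it goes through the normal optimizer update.</li>
<li><code class="highlight"><span></span><span class="n">absolute</span></code>
 is a flag indicating whether the weight decay coefficient is absolute. It applies only when the decay is performed directly on the parameter. If it is false, the actual decay is <code class="highlight"><span></span><span class="n">weight_decay</span></code>
 * <code class="highlight"><span></span><span class="n">learning_rate</span></code>.
</li></ul>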
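<p>The three combinations of these flags can be sketched with plain tensors. The helper <code class="highlight"><span></span><span class="n">decay_sketch</span></code> is a hypothetical function written for this example; unlike the class on this page it returns new tensors instead of modifying the parameter in place, and the numbers are arbitrary.</p>
<pre class="highlight lang-python"><code>import torch


def decay_sketch(param, grad, lr, weight_decay, weight_decouple, absolute):
    """Return (param, grad) after applying one of the three decay variants."""
    if weight_decouple:
        # Decay the parameter directly and leave the gradient untouched
        if absolute:
            factor = 1.0 - weight_decay
        else:
            factor = 1.0 - lr * weight_decay
        return param * factor, grad
    # Otherwise fold the decay into the gradient (classic L2 regularization)
    return param, grad + weight_decay * param


param = torch.tensor([2.0])
grad = torch.tensor([0.5])

# Decoupled, scaled by the learning rate: param * (1 - 0.1 * 0.01)
decoupled = decay_sketch(param, grad, 0.1, 0.01, True, False)
# Decoupled, absolute coefficient: param * (1 - 0.01)
absolute = decay_sketch(param, grad, 0.1, 0.01, True, True)
# Coupled: grad + 0.01 * param
coupled = decay_sketch(param, grad, 0.1, 0.01, False, False)</code></pre>
<p>In the coupled case the decay term passes through the adaptive update along with the gradient, so its effective size depends on the optimizer's running averages. In the decoupled cases it shrinks the parameter by a fixed factor.</p>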
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">weight_decouple</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">absolute</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Check the hyper-parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">weight_decay</span><span class="p">:</span>
<span class="lineno">185</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid weight_decay value: </span><span class="si">{</span><span class="n">weight_decay</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="lineno">186</span>
<span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">absolute</span> <span class="o">=</span> <span class="n">absolute</span>
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decouple</span> <span class="o">=</span> <span class="n">weight_decouple</span>
<span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">=</span> <span class="n">weight_decay</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Return the defaults for parameter groups</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="k">def</span> <span class="nf">defaults</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="k">return</span> <span class="nb">dict</span><span class="p">(</span><span class="n">weight_decay</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<h3>Perform weight decay and return the gradient</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">,</span> <span class="n">grad</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>If we are doing the decay on the parameter directly</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decouple</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>If the weight decay coefficient is absolute</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">absolute</span><span class="p">:</span>
<span class="lineno">206</span> <span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Otherwise, scale the coefficient by the learning rate</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">208</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">209</span> <span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">]</span> <span class="o">*</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Return the unmodified gradient</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">211</span> <span class="k">return</span> <span class="n">grad</span>
<span class="lineno">212</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">213</span> <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Add the weight decay to the gradient and return the modified gradient</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">215</span> <span class="k">return</span> <span class="n">grad</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">])</span>
<span class="lineno">216</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">217</span> <span class="k">return</span> <span class="n">grad</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>