mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-20 21:21:17 +08:00
583 lines
40 KiB
HTML
<!DOCTYPE html>
|
||
<html lang="en">
|
||
<head>
|
||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||
<meta name="description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad and RAdam optimizers."/>
|
||
|
||
<meta name="twitter:card" content="summary"/>
|
||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||
<meta name="twitter:title" content="Optimizers"/>
|
||
<meta name="twitter:description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad and RAdam optimizers."/>
|
||
<meta name="twitter:site" content="@labmlai"/>
|
||
<meta name="twitter:creator" content="@labmlai"/>
|
||
|
||
<meta property="og:url" content="https://nn.labml.ai/optimizers/index.html"/>
|
||
<meta property="og:title" content="Optimizers"/>
|
||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||
<meta property="og:site_name" content="Optimizers"/>
|
||
<meta property="og:type" content="object"/>
|
||
<meta property="og:title" content="Optimizers"/>
|
||
<meta property="og:description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad and RAdam optimizers."/>
|
||
|
||
<title>Optimizers</title>
|
||
<link rel="shortcut icon" href="/icon.png"/>
|
||
<link rel="stylesheet" href="../pylit.css?v=1">
|
||
<link rel="canonical" href="https://nn.labml.ai/optimizers/index.html"/>
|
||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||
|
||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||
<script>
|
||
window.dataLayer = window.dataLayer || [];
|
||
|
||
function gtag() {
|
||
dataLayer.push(arguments);
|
||
}
|
||
|
||
gtag('js', new Date());
|
||
|
||
gtag('config', 'G-4V3HC8HBLH');
|
||
</script>
|
||
</head>
|
||
<body>
|
||
<div id='container'>
|
||
<div id="background"></div>
|
||
<div class='section'>
|
||
<div class='docs'>
|
||
<p>
|
||
<a class="parent" href="/">home</a>
|
||
<a class="parent" href="index.html">optimizers</a>
|
||
</p>
|
||
<p>
|
||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
|
||
<img alt="Github"
|
||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||
style="max-width:100%;"/></a>
|
||
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
|
||
<img alt="Twitter"
|
||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||
style="max-width:100%;"/></a>
|
||
</p>
|
||
<p>
|
||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/optimizers/__init__.py" target="_blank">
|
||
View code on Github</a>
|
||
</p>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-0'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-0'>#</a>
|
||
</div>
|
||
<h1>Optimizers</h1>
|
||
<h2>Optimizer Implementations</h2>
|
||
<ul><li><a href="adam.html">Adam Optimizer</a> </li>
|
||
<li><a href="amsgrad.html">AMSGrad Optimizer</a> </li>
|
||
<li><a href="adam_warmup.html">Adam Optimizer with warmup</a> </li>
|
||
<li><a href="noam.html">Noam Optimizer</a> </li>
|
||
<li><a href="radam.html">Rectified Adam Optimizer</a> </li>
|
||
<li><a href="ada_belief.html">AdaBelief Optimizer</a> </li>
|
||
<li><a href="sophia.html">Sophia-G Optimizer</a></li></ul>
|
||
<p>This <a href="mnist_experiment.html">MNIST example</a> uses these optimizers.</p>
|
||
<h2>Generic Adaptive Optimizer Base class and Weight Decay</h2>
|
||
<p>This file defines a common base class for <em>Adam</em> and extensions of it. The base class helps us implement other optimizers with minimal code, because the shared logic is reused.</p>
|
||
<p>We also define a special class for L2 weight decay, so that we don't have to implement it inside each of the optimizers, and can easily extend to other weight decays like L1 without changing the optimizers.</p>
|
||
<p>Here are some concepts on PyTorch optimizers:</p>
|
||
<h3>Parameter groups</h3>
|
||
<p>PyTorch optimizers group parameters into sets called groups. Each group can have its own hyper-parameters like learning rates.</p>
|
||
<p>In the most common case there is only one group, which is what you get when you initialize the optimizer with:</p>
|
||
<pre class="highlight lang-python"><code><span></span><span class="n">Optimizer</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span></code></pre>
|
||
<p>You can define multiple parameter groups when initializing the optimizer:</p>
|
||
<pre class="highlight lang-python"><code><span></span><span class="n">Optimizer</span><span class="p">([{</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">model1</span><span class="o">.</span><span class="n">parameters</span><span class="p">()},</span> <span class="p">{</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">model2</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="s1">'lr'</span><span class="p">:</span> <span class="mi">2</span><span class="p">}])</span></code></pre>
|
||
<p>Here we pass a list of groups. Each group is a dictionary with its parameters under the key 'params'. You can specify hyper-parameters for a group as well; any hyper-parameters that are not defined default to the optimizer-level defaults.</p>
|
||
<p>You can access (and even change) these groups and their hyper-parameters through <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span></code>. Most learning rate schedule implementations I've come across access this and change 'lr'.</p>
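<p>As a rough illustration of parameter groups, the sketch below uses <code>torch.optim.SGD</code> as a stand-in optimizer. The two small <code>nn.Linear</code> models and the halving loop are assumptions made for the example; they are not part of this file.</p>

```python
import torch
from torch import nn

# Two tiny stand-in models (hypothetical; any nn.Module would do)
model1 = nn.Linear(4, 2)
model2 = nn.Linear(2, 1)

# The second group overrides `lr`; the first falls back to the optimizer-level default
optimizer = torch.optim.SGD(
    [{'params': model1.parameters()},
     {'params': model2.parameters(), 'lr': 2.0}],
    lr=0.1,
)

lrs = [group['lr'] for group in optimizer.param_groups]

# A hand-rolled schedule step: halve the learning rate of every group in place
for group in optimizer.param_groups:
    group['lr'] *= 0.5

halved = [group['lr'] for group in optimizer.param_groups]
```

<p>Mutating <code>group['lr']</code> in place like this is, in essence, what a learning rate scheduler does between steps.</p>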
|
||
<h3>States</h3>
|
||
<p>The optimizer maintains a state (a dictionary) for each parameter (a tensor), in the dictionary <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">state</span></code>. This is where the optimizer maintains things like exponential averages.</p>
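<p>To make the per-parameter state concrete, here is a small sketch that inspects the state of PyTorch's built-in <code>torch.optim.Adam</code> after a single step. The model and input are made up for the example, and the key names shown are those used by the built-in Adam, not necessarily by the optimizers defined in this package.</p>

```python
import torch
from torch import nn

model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# State is created lazily, so nothing is stored before the first step
empty_before = len(optimizer.state) == 0

loss = model(torch.ones(2, 3)).sum()
loss.backward()
optimizer.step()

# After one step, each parameter tensor maps to its own state dictionary
weight_state = optimizer.state[model.weight]
state_keys = sorted(weight_state.keys())
```

<p>For the built-in Adam, <code>state_keys</code> includes <code>exp_avg</code> and <code>exp_avg_sq</code>, the exponential averages of the gradient and of the squared gradient.</p>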
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">63</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Any</span>
|
||
<span class="lineno">64</span>
|
||
<span class="lineno">65</span><span class="kn">import</span> <span class="nn">torch</span>
|
||
<span class="lineno">66</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||
<span class="lineno">67</span><span class="kn">from</span> <span class="nn">torch.optim.optimizer</span> <span class="kn">import</span> <span class="n">Optimizer</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-1'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-1'>#</a>
|
||
</div>
|
||
<h2>Base class for <em>Adam</em> and extensions</h2>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">70</span><span class="k">class</span> <span class="nc">GenericAdaptiveOptimizer</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-2'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-2'>#</a>
|
||
</div>
|
||
<h3>Initialize</h3>
|
||
<ul><li><code class="highlight"><span></span><span class="n">params</span></code>
|
||
is the collection of parameters or a set of parameter groups.</li>
|
||
<li><code class="highlight"><span></span><span class="n">defaults</span></code>
|
||
is a dictionary of default hyper-parameters.</li>
|
||
<li><code class="highlight"><span></span><span class="n">lr</span></code>
|
||
is the learning rate, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.0037em;">α</span></span></span></span></span></li>
|
||
<li><code class="highlight"><span></span><span class="n">betas</span></code>
|
||
is the tuple <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span></li>
|
||
<li><code class="highlight"><span></span><span class="n">eps</span></code>
is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal">ϵ</span></span></span></span></span></li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">75</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">defaults</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">betas</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-3'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-3'>#</a>
|
||
</div>
|
||
<p>Check the hyper-parameters</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">87</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">lr</span><span class="p">:</span>
|
||
<span class="lineno">88</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid learning rate: </span><span class="si">{</span><span class="n">lr</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="lineno">89</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">eps</span><span class="p">:</span>
|
||
<span class="lineno">90</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid epsilon value: </span><span class="si">{</span><span class="n">eps</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="lineno">91</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">betas</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;</span> <span class="mf">1.0</span><span class="p">:</span>
|
||
<span class="lineno">92</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid beta parameter at index 0: </span><span class="si">{</span><span class="n">betas</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="lineno">93</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">betas</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;</span> <span class="mf">1.0</span><span class="p">:</span>
|
||
<span class="lineno">94</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid beta parameter at index 1: </span><span class="si">{</span><span class="n">betas</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-4'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-4'>#</a>
|
||
</div>
|
||
<p>Add the hyper-parameters to the defaults</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">97</span> <span class="n">defaults</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">dict</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">))</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-5'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-5'>#</a>
|
||
</div>
|
||
<p>Initialize the PyTorch optimizer. This will create parameter groups with the default hyper-parameters.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">100</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">defaults</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-6'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-6'>#</a>
|
||
</div>
|
||
<h3>Initialize state for a given parameter tensor</h3>
|
||
<p>This should be overridden with code to initialize <code class="highlight"><span></span><span class="n">state</span></code>
for the parameter <code class="highlight"><span></span><span class="n">param</span></code>.
<code class="highlight"><span></span><span class="n">group</span></code>
is the parameter group dictionary to which <code class="highlight"><span></span><span class="n">param</span></code>
belongs.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">102</span> <span class="k">def</span> <span class="nf">init_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">param</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-7'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-7'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">109</span> <span class="k">pass</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-8'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-8'>#</a>
|
||
</div>
|
||
<h3>Take optimizer step on a parameter tensor</h3>
|
||
<p>This should be overridden to take the optimization step on the <code class="highlight"><span></span><span class="n">param</span></code>
tensor <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span></span></span></span></span>, where <code class="highlight"><span></span><span class="n">grad</span></code>
is the gradient for that parameter, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>, <code class="highlight"><span></span><span class="n">state</span></code>
is the optimizer state dictionary for that parameter, and <code class="highlight"><span></span><span class="n">group</span></code>
is the parameter group dictionary that <code class="highlight"><span></span><span class="n">param</span></code>
belongs to.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">111</span> <span class="k">def</span> <span class="nf">step_param</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">grad</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-9'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-9'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">120</span> <span class="k">pass</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-10'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-10'>#</a>
|
||
</div>
|
||
<h3>Optimizer step</h3>
|
||
<p>We have created a template method that does the common work every <em>Adam</em>-based optimizer needs.</p>
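<p>The template-method idea can be sketched in isolation as follows. <code>TemplateOptimizer</code> and <code>CountingSGD</code> are hypothetical names invented for this example: the first is a simplified stand-in for the base class (it omits the hyper-parameter and sparse-gradient checks), and the second is a toy subclass that only fills in the two hooks.</p>

```python
import torch
from torch import nn
from torch.optim.optimizer import Optimizer


class TemplateOptimizer(Optimizer):
    """Hypothetical, simplified stand-in for the base class: `step` owns the loops."""

    def init_state(self, state, group, param):
        pass

    def step_param(self, state, group, grad, param):
        pass

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    continue
                state = self.state[param]
                # Lazily initialize the state the first time a parameter is seen
                if len(state) == 0:
                    self.init_state(state, group, param)
                self.step_param(state, group, param.grad, param)
        return loss


class CountingSGD(TemplateOptimizer):
    """Toy subclass: plain gradient descent that also counts its steps."""

    def __init__(self, params, lr: float = 0.1):
        super().__init__(params, dict(lr=lr))

    def init_state(self, state, group, param):
        state['step'] = 0

    def step_param(self, state, group, grad, param):
        state['step'] += 1
        param.add_(grad, alpha=-group['lr'])


w = nn.Parameter(torch.tensor([1.0, 2.0]))
opt = CountingSGD([w], lr=0.5)
(w ** 2).sum().backward()  # the gradient is 2 * w = [2, 4]
opt.step()                 # w becomes [1 - 0.5 * 2, 2 - 0.5 * 4]
```

<p>The subclass never touches the loops over groups and parameters; it only says how to initialize state and how to update one tensor.</p>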
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">122</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
|
||
<span class="lineno">123</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">closure</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-11'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-11'>#</a>
|
||
</div>
|
||
<p>Calculate the loss.</p>
|
||
<p>🤔 I'm not sure when you need this. I guess it's for when you define a function that calculates the loss, calls <code class="highlight"><span></span><span class="n">loss</span><span class="o">.</span><span class="n">backward</span></code>
and returns the loss; instead of calling it yourself, you can pass it to <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span></code>.
🤷‍♂️</p>
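<p>For reference, a closure is typically used along the lines of the sketch below, shown here with the built-in <code>torch.optim.SGD</code>. The model, data and loss are made up for the example. Some built-in optimizers, such as <code>torch.optim.LBFGS</code>, require a closure because they re-evaluate the loss several times per step.</p>

```python
import torch
from torch import nn

model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.ones(4, 2), torch.zeros(4, 1)


def closure():
    # Re-evaluate the model: clear old gradients, compute the loss, backpropagate
    optimizer.zero_grad()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    return loss


# `step` calls the closure itself and returns the loss it computed
loss = optimizer.step(closure)
```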
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">loss</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="lineno">135</span> <span class="k">if</span> <span class="n">closure</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="lineno">136</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">enable_grad</span><span class="p">():</span>
|
||
<span class="lineno">137</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">closure</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-12'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-12'>#</a>
|
||
</div>
|
||
<p>Iterate through the parameter groups</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">140</span> <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-13'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-13'>#</a>
|
||
</div>
|
||
<p>Iterate through the parameters in the parameter group</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">142</span> <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]:</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-14'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-14'>#</a>
|
||
</div>
|
||
<p>Skip if the parameter has no gradient</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">144</span> <span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="lineno">145</span> <span class="k">continue</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-15'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-15'>#</a>
|
||
</div>
|
||
<p>Get the gradient tensor</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-16'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-16'>#</a>
|
||
</div>
|
||
<p>We don't handle sparse gradients</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">149</span> <span class="k">if</span> <span class="n">grad</span><span class="o">.</span><span class="n">is_sparse</span><span class="p">:</span>
|
||
<span class="lineno">150</span> <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">'GenericAdaptiveOptimizer does not support sparse gradients,'</span>
|
||
<span class="lineno">151</span> <span class="s1">' please consider SparseAdam instead'</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-17'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-17'>#</a>
|
||
</div>
|
||
<p>Get the state for the parameter</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="n">param</span><span class="p">]</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-18'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-18'>#</a>
|
||
</div>
|
||
<p>Initialize the state if it has not been initialized yet</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">157</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">state</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="lineno">158</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_state</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">param</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-19'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-19'>#</a>
|
||
</div>
|
||
<p>Take the optimization step on the parameter</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">161</span> <span class="bp">self</span><span class="o">.</span><span class="n">step_param</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">grad</span><span class="p">,</span> <span class="n">param</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-20'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-20'>#</a>
|
||
</div>
|
||
<p>Return the loss calculated by the closure</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">164</span> <span class="k">return</span> <span class="n">loss</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-21'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-21'>#</a>
|
||
</div>
|
||
<h2>L2 Weight decay</h2>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">167</span><span class="k">class</span> <span class="nc">WeightDecay</span><span class="p">:</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-22'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-22'>#</a>
|
||
</div>
|
||
<h3>Initialize weight decay</h3>
|
||
<ul><li><code class="highlight"><span></span><span class="n">weight_decay</span></code>
|
||
is the decay coefficient.</li>
|
||
<li><code class="highlight"><span></span><span class="n">weight_decouple</span></code>
|
||
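is a flag indicating whether to add the weight decay to the gradient or to decay the parameter directly. If it is added to the gradient, it goes through the normal optimizer update.</li>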
|
||
<li><code class="highlight"><span></span><span class="n">absolute</span></code>
|
||
indicates whether the weight decay coefficient is absolute. This applies when the decay is performed directly on the parameter. If it is false, the actual decay is <code class="highlight"><span></span><span class="n">weight_decay</span> <span class="o">*</span> <span class="n">learning_rate</span></code>.</li></ul>
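<p>The three modes described above can be worked through on a single scalar weight. The sketch below is plain Python rather than the tensor code of this class; the function name <code>decay</code> and the numbers are assumptions chosen for the example.</p>

```python
def decay(param: float, grad: float, lr: float, weight_decay: float,
          weight_decouple: bool = True, absolute: bool = False):
    """Return (param, grad) after applying weight decay to one scalar weight."""
    if weight_decouple:
        # Shrink the parameter directly; the gradient is left untouched
        factor = weight_decay if absolute else lr * weight_decay
        return param * (1.0 - factor), grad
    # Otherwise fold the decay into the gradient, as in classic L2 regularization
    return param, grad + weight_decay * param


# param = 2.0, grad = 0.5, lr = 0.1, weight_decay = 0.5 in all three cases
decoupled = decay(2.0, 0.5, lr=0.1, weight_decay=0.5)
absolute = decay(2.0, 0.5, lr=0.1, weight_decay=0.5, absolute=True)
coupled = decay(2.0, 0.5, lr=0.1, weight_decay=0.5, weight_decouple=False)
```

<p>With these numbers the decoupled modes scale the weight by <code>1 - 0.1 * 0.5</code> and <code>1 - 0.5</code> respectively, while the coupled mode leaves the weight alone and adds <code>0.5 * 2.0</code> to the gradient.</p>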
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">172</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">weight_decouple</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">absolute</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-23'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-23'>#</a>
|
||
</div>
|
||
<p>Check the hyper-parameters</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">185</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">weight_decay</span><span class="p">:</span>
|
||
<span class="lineno">186</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid weight_decay value: </span><span class="si">{</span><span class="n">weight_decay</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="lineno">187</span>
|
||
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">absolute</span> <span class="o">=</span> <span class="n">absolute</span>
|
||
<span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decouple</span> <span class="o">=</span> <span class="n">weight_decouple</span>
|
||
<span class="lineno">190</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">=</span> <span class="n">weight_decay</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-24'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-24'>#</a>
|
||
</div>
|
||
<p>Return the defaults for parameter groups</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">192</span> <span class="k">def</span> <span class="nf">defaults</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-25'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-25'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">196</span> <span class="k">return</span> <span class="nb">dict</span><span class="p">(</span><span class="n">weight_decay</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<h3>Perform weight decay and return the gradient</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">198</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">,</span> <span class="n">grad</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>If we are applying the decay directly to the parameter (decoupled weight decay)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">204</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decouple</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>If the weight decay coefficient is absolute, shrink the parameter by it directly, independent of the learning rate</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">206</span>            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">absolute</span><span class="p">:</span>
<span class="lineno">207</span>                <span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">group</span><span class="p">[</span><span class="s1">'weight_decay'</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Otherwise, scale the decay coefficient by the learning rate</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">209</span>            <span class="k">else</span><span class="p">:</span>
<span class="lineno">210</span>                <span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">group</span><span class="p">[</span><span class="s1">'lr'</span><span class="p">]</span> <span class="o">*</span> <span class="n">group</span><span class="p">[</span><span class="s1">'weight_decay'</span><span class="p">])</span></pre></div>
</div>
</div>
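<p>The two decoupled branches differ only in whether the shrink factor involves the learning rate. A minimal sketch of that arithmetic, assuming a plain Python float stands in for the parameter tensor. The function <code>decoupled_decay</code> and the sample values are illustrative, not part of this implementation:</p>

```python
def decoupled_decay(param: float, lr: float, weight_decay: float, absolute: bool) -> float:
    """Illustrative scalar version of the decoupled branch (returns the shrunk value)."""
    if absolute:
        # Absolute coefficient: shrink by a fixed fraction, independent of `lr`
        return param * (1.0 - weight_decay)
    # Otherwise the decay coefficient is scaled by the learning rate
    return param * (1.0 - lr * weight_decay)


# Same parameter and coefficient, both modes
absolute_result = decoupled_decay(2.0, lr=0.1, weight_decay=0.01, absolute=True)
scaled_result = decoupled_decay(2.0, lr=0.1, weight_decay=0.01, absolute=False)
```

<p>With these sample values the absolute mode shrinks the parameter by 1% per step, while the learning-rate-scaled mode shrinks it by 0.1%, so the same <code>weight_decay</code> value means a much stronger decay when <code>absolute</code> is set and the learning rate is small.</p>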
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Return the unmodified gradient</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">212</span>            <span class="k">return</span> <span class="n">grad</span>
<span class="lineno">213</span>        <span class="k">else</span><span class="p">:</span>
<span class="lineno">214</span>            <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">'weight_decay'</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Add the weight decay to the gradient and return the modified gradient</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">216</span>                <span class="k">return</span> <span class="n">grad</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="n">group</span><span class="p">[</span><span class="s1">'weight_decay'</span><span class="p">])</span>
<span class="lineno">217</span>            <span class="k">else</span><span class="p">:</span>
<span class="lineno">218</span>                <span class="k">return</span> <span class="n">grad</span></pre></div>
</div>
</div>
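<p>When the decay is not decoupled, it is folded into the gradient instead of being applied to the parameter. A minimal sketch of that branch, assuming plain Python lists stand in for tensors. The function <code>coupled_decay</code> is illustrative and mirrors <code>grad.add(param.data, alpha=weight_decay)</code> element by element:</p>

```python
from typing import List


def coupled_decay(param: List[float], grad: List[float], weight_decay: float) -> List[float]:
    """Illustrative list version of the non-decoupled branch."""
    if weight_decay != 0:
        # Build a new list, so the incoming gradient is left unchanged,
        # like the out-of-place `add` in the code above
        return [g + weight_decay * p for p, g in zip(param, grad)]
    # Zero coefficient: return the gradient as it is
    return grad


param = [1.0, -2.0]
grad = [0.5, 0.5]
decayed_grad = coupled_decay(param, grad, weight_decay=0.1)
same_grad = coupled_decay(param, grad, weight_decay=0.0)
```

<p>Each gradient entry moves by <code>weight_decay</code> times the matching parameter entry, so larger weights are pulled toward zero more strongly. The parameter itself is not touched in this branch.</p>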
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
    function handleImages() {
        var images = document.querySelectorAll('p>img')

        for (var i = 0; i < images.length; ++i) {
            handleImage(images[i])
        }
    }

    function handleImage(img) {
        img.parentElement.style.textAlign = 'center'

        var modal = document.createElement('div')
        modal.id = 'modal'

        var modalContent = document.createElement('div')
        modal.appendChild(modalContent)

        var modalImage = document.createElement('img')
        modalContent.appendChild(modalImage)

        var span = document.createElement('span')
        span.classList.add('close')
        span.textContent = 'x'
        modal.appendChild(span)

        img.onclick = function () {
            console.log('clicked')
            document.body.appendChild(modal)
            modalImage.src = img.src
        }

        span.onclick = function () {
            document.body.removeChild(modal)
        }
    }

    handleImages()
</script>
</body>
</html>