<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad and RAdam optimizers."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Optimizers"/>
<meta name="twitter:description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad and RAdam optimizers."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/optimizers/index.html"/>
<meta property="og:title" content="Optimizers"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Optimizers"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad and RAdam optimizers."/>
<title>Optimizers</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/optimizers/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">optimizers</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/optimizers/__init__.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Optimizers</h1>
<h2>Optimizer Implementations</h2>
<ul><li><a href="adam.html">Adam Optimizer</a> </li>
<li><a href="amsgrad.html">AMSGrad Optimizer</a> </li>
<li><a href="adam_warmup.html">Adam Optimizer with warmup</a> </li>
<li><a href="noam.html">Noam Optimizer</a> </li>
<li><a href="radam.html">Rectified Adam Optimizer</a> </li>
<li><a href="ada_belief.html">AdaBelief Optimizer</a> </li>
<li><a href="sophia.html">Sophia-G Optimizer</a></li></ul>
<p>This <a href="mnist_experiment.html">MNIST example</a> uses these optimizers.</p>
<h2>Generic Adaptive Optimizer Base class and Weight Decay</h2>
<p>This file defines a common base class for <em>Adam</em> and its extensions. The base class helps us implement other optimizers with minimal code because the shared logic is reused.</p>
<p>We also define a special class for L2 weight decay, so that we don&#x27;t have to implement it inside each of the optimizers, and can easily extend to other weight decays like L1 without changing the optimizers.</p>
<p>Here are some concepts on PyTorch optimizers:</p>
<h3>Parameter groups</h3>
<p>PyTorch optimizers group parameters into sets called groups. Each group can have its own hyper-parameters like learning rates.</p>
<p>In the most common case there is only one group. This is what you get when you initialize your optimizer with:</p>
<pre class="highlight lang-python"><code><span></span><span class="n">Optimizer</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span></code></pre>
<p>You can define multiple parameter groups when initializing the optimizer:</p>
<pre class="highlight lang-python"><code><span></span><span class="n">Optimizer</span><span class="p">([{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="n">model1</span><span class="o">.</span><span class="n">parameters</span><span class="p">()},</span> <span class="p">{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="n">model2</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="s1">&#39;lr&#39;</span><span class="p">:</span> <span class="mi">2</span><span class="p">}])</span></code></pre>
<p>Here we pass a list of groups. Each group is a dictionary with its parameters under the key &#x27;params&#x27;. You can specify any hyper-parameters in the group as well. Hyper-parameters that are not defined in a group default to the optimizer-level defaults.</p>
<p>You can access (and even change) these groups, and their hyper-parameters with <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span></code>
. Most learning rate schedule implementations I&#x27;ve come across do access this and change &#x27;lr&#x27;.</p>
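<p>A minimal sketch of parameter groups in practice, assuming PyTorch is installed. It uses the built-in <code class="highlight">torch.optim.Adam</code> rather than the optimizers on this page, and <code class="highlight">model1</code> and <code class="highlight">model2</code> are small stand-in modules chosen only for illustration.</p>

```python
import torch
from torch import nn

# Two small stand-in models (hypothetical; any nn.Module works)
model1 = nn.Linear(4, 2)
model2 = nn.Linear(2, 1)

# The second group overrides 'lr'; the first falls back to the optimizer default
optimizer = torch.optim.Adam(
    [{'params': model1.parameters()},
     {'params': model2.parameters(), 'lr': 2e-3}],
    lr=1e-3,
)

group_lrs = [group['lr'] for group in optimizer.param_groups]

# A hand-rolled schedule: halve the learning rate of every group
for group in optimizer.param_groups:
    group['lr'] *= 0.5

halved_lrs = [group['lr'] for group in optimizer.param_groups]
```

<p>Each group also carries the hyper-parameters it did not override (for Adam, keys such as <code class="highlight">betas</code> and <code class="highlight">eps</code>), so a schedule can adjust any of them the same way.</p>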
<h3>States</h3>
<p>The optimizer maintains a state (a dictionary) for each parameter (a tensor), in the dictionary <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">state</span></code>
. This is where the optimizer keeps things like exponential averages.</p>
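<p>To make the state dictionary concrete, here is a short sketch using the built-in <code class="highlight">torch.optim.Adam</code> (an assumption for illustration; the optimizers on this page choose their own state keys). The state is empty until the first step and then holds one dictionary per parameter tensor.</p>

```python
import torch
from torch import nn

model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# No state exists until the first step
state_size_before = len(optimizer.state)

loss = model(torch.randn(8, 3)).pow(2).mean()
loss.backward()
optimizer.step()

# After one step there is one state dictionary per parameter tensor
state_keys = sorted(optimizer.state[model.weight].keys())
exp_avg_shape = optimizer.state[model.weight]['exp_avg'].shape
```

<p>For the built-in Adam the keys include <code class="highlight">exp_avg</code> and <code class="highlight">exp_avg_sq</code>, the running averages of the gradient and of its square.</p>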
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Any</span>
<span class="lineno">64</span>
<span class="lineno">65</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">66</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">67</span><span class="kn">from</span> <span class="nn">torch.optim.optimizer</span> <span class="kn">import</span> <span class="n">Optimizer</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Base class for <em>Adam</em> and extensions</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span><span class="k">class</span> <span class="nc">GenericAdaptiveOptimizer</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<h3>Initialize</h3>
<ul><li><code class="highlight"><span></span><span class="n">params</span></code>
is the collection of parameters or set of parameter groups. </li>
<li><code class="highlight"><span></span><span class="n">defaults</span></code>
is a dictionary of default hyper-parameters </li>
<li><code class="highlight"><span></span><span class="n">lr</span></code>
is the learning rate, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.0037em;">α</span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">betas</span></code>
is the tuple <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">eps</span></code>
is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal">ϵ</span></span></span></span></span></li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">defaults</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">betas</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>Check the hyper-parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">lr</span><span class="p">:</span>
<span class="lineno">88</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid learning rate: </span><span class="si">{</span><span class="n">lr</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="lineno">89</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">eps</span><span class="p">:</span>
<span class="lineno">90</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid epsilon value: </span><span class="si">{</span><span class="n">eps</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="lineno">91</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">betas</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;</span> <span class="mf">1.0</span><span class="p">:</span>
<span class="lineno">92</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid beta parameter at index 0: </span><span class="si">{</span><span class="n">betas</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="lineno">93</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">betas</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;</span> <span class="mf">1.0</span><span class="p">:</span>
<span class="lineno">94</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid beta parameter at index 1: </span><span class="si">{</span><span class="n">betas</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Add the hyper-parameters to the defaults </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span> <span class="n">defaults</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">dict</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Initialize the PyTorch optimizer. This will create parameter groups with the default hyper-parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">defaults</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<h3>Initialize state for a given parameter tensor</h3>
<p>This should be overridden with code to initialize <code class="highlight"><span></span><span class="n">state</span></code>
for the parameter <code class="highlight"><span></span><span class="n">param</span></code>
. <code class="highlight"><span></span><span class="n">group</span></code>
is the parameter group dictionary to which <code class="highlight"><span></span><span class="n">param</span></code>
belongs.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">102</span> <span class="k">def</span> <span class="nf">init_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">param</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">109</span> <span class="k">pass</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<h3>Take optimizer step on a parameter tensor</h3>
<p>This should be overridden and take the optimization step on <code class="highlight"><span></span><span class="n">param</span></code>
tensor <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span></span></span></span></span>, where <code class="highlight"><span></span><span class="n">grad</span></code>
is the gradient for that parameter, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>, <code class="highlight"><span></span><span class="n">state</span></code>
is the optimizer state dictionary for that parameter, and <code class="highlight"><span></span><span class="n">group</span></code>
is the parameter group dictionary <code class="highlight"><span></span><span class="n">param</span></code>
belongs to.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">111</span> <span class="k">def</span> <span class="nf">step_param</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">grad</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">120</span> <span class="k">pass</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<h3>Optimizer step</h3>
<p>We have created a template method that handles the steps common to every <em>Adam</em>-based optimizer.</p>
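<p>The template-method idea can be sketched end to end as follows. This is a simplified illustration, not the code on this page: <code class="highlight">TemplateOptimizer</code> is a hypothetical stand-in for the base class (it omits the hyper-parameter checks and the sparse-gradient check), and <code class="highlight">TinyAdam</code> is a hypothetical subclass that implements plain Adam with bias correction through the two hooks.</p>

```python
import torch
from torch import nn
from torch.optim.optimizer import Optimizer


class TemplateOptimizer(Optimizer):
    """Hypothetical stand-in that mirrors the base class on this page."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        super().__init__(params, dict(lr=lr, betas=betas, eps=eps))

    def init_state(self, state, group, param):
        pass  # hook: create the per-parameter state

    def step_param(self, state, group, grad, param):
        pass  # hook: update a single parameter tensor

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    continue
                state = self.state[param]
                # State is created lazily, the first time a parameter is stepped
                if len(state) == 0:
                    self.init_state(state, group, param)
                self.step_param(state, group, param.grad, param)
        return loss


class TinyAdam(TemplateOptimizer):
    """Plain Adam with bias correction, written against the two hooks."""

    def init_state(self, state, group, param):
        state['step'] = 0
        state['exp_avg'] = torch.zeros_like(param)
        state['exp_avg_sq'] = torch.zeros_like(param)

    def step_param(self, state, group, grad, param):
        beta1, beta2 = group['betas']
        state['step'] += 1
        m, v = state['exp_avg'], state['exp_avg_sq']
        # Update the running averages of the gradient and its square in place
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # Bias-corrected estimates
        m_hat = m / (1 - beta1 ** state['step'])
        v_hat = v / (1 - beta2 ** state['step'])
        param.addcdiv_(m_hat, v_hat.sqrt() + group['eps'], value=-group['lr'])


param = nn.Parameter(torch.tensor([1.0, -2.0]))
optimizer = TinyAdam([param], lr=0.1)
param.pow(2).sum().backward()
optimizer.step()
```

<p>On the first step the bias-corrected update is close to the learning rate times the sign of the gradient, so each entry of <code class="highlight">param</code> moves by roughly 0.1 toward zero.</p>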
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">123</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">closure</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Calculate loss.</p>
<p>The closure is an optional function that re-evaluates the model: it computes the loss, calls <code class="highlight"><span></span><span class="n">loss</span><span class="o">.</span><span class="n">backward</span></code>
and returns the loss. Instead of calling it yourself, you pass it to <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span></code>
. Optimizers that evaluate the loss several times per step, such as PyTorch&#x27;s LBFGS, need it; here it is called once if provided.</p>
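<p>As a rough illustration of the closure pattern, the sketch below uses the built-in <code class="highlight">torch.optim.LBFGS</code>, which requires a closure. The model, the data and the <code class="highlight">counter</code> dictionary are assumptions made up for this example.</p>

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(2, 1)
inputs, targets = torch.randn(16, 2), torch.randn(16, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1, max_iter=5)

# Counts how many times the optimizer re-evaluates the loss
counter = {'calls': 0}


def closure():
    # Clear old gradients, recompute the loss, and backpropagate
    counter['calls'] += 1
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    return loss


# The optimizer decides when, and how often, to call the closure
loss = optimizer.step(closure)
```

<p>The same <code class="highlight">closure</code> can be passed to an optimizer that follows the template above; it would then be called exactly once per step.</p>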
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">loss</span> <span class="o">=</span> <span class="kc">None</span>
<span class="lineno">135</span> <span class="k">if</span> <span class="n">closure</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">136</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">enable_grad</span><span class="p">():</span>
<span class="lineno">137</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">closure</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Iterate through the parameter groups </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Iterate through the parameters in the parameter group </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;params&#39;</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Skip if the parameter has no gradient </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">145</span> <span class="k">continue</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Get the gradient tensor </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>We don&#x27;t handle sparse gradients </p>
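<p>For context, a sparse gradient typically comes from an embedding layer created with <code class="highlight">sparse=True</code>. The short sketch below (layer sizes are arbitrary assumptions) shows how the <code class="highlight">is_sparse</code> flag tells the two cases apart.</p>

```python
import torch
from torch import nn

# An embedding with sparse=True produces a sparse gradient for its weight
embedding = nn.Embedding(10, 3, sparse=True)
embedding(torch.tensor([1, 4])).sum().backward()

is_sparse = embedding.weight.grad.is_sparse

# The same layer without the flag produces an ordinary dense gradient
dense = nn.Embedding(10, 3)
dense(torch.tensor([1, 4])).sum().backward()

is_dense_sparse = dense.weight.grad.is_sparse
```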
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">149</span> <span class="k">if</span> <span class="n">grad</span><span class="o">.</span><span class="n">is_sparse</span><span class="p">:</span>
<span class="lineno">150</span> <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;GenericAdaptiveOptimizer does not support sparse gradients,&#39;</span>
<span class="lineno">151</span> <span class="s1">&#39; please consider SparseAdam instead&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Get the state for the parameter </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="n">param</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Initialize the state if state is uninitialized </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">state</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">158</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_state</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">param</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Take the optimization step on the parameter </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">161</span> <span class="bp">self</span><span class="o">.</span><span class="n">step_param</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">grad</span><span class="p">,</span> <span class="n">param</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Return the loss, calculated from closure </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="k">return</span> <span class="n">loss</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<h2>L2 Weight decay</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">167</span><span class="k">class</span> <span class="nc">WeightDecay</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<h3>Initialize weight decay</h3>
<ul><li><code class="highlight"><span></span><span class="n">weight_decay</span></code>
is the decay coefficient </li>
<li><code class="highlight"><span></span><span class="n">weight_decouple</span></code>
is a flag indicating whether to decay the parameter directly (decoupled) or to add the weight decay to the gradient. If it is added to the gradient, it goes through the normal optimizer update. </li>
<li><code class="highlight"><span></span><span class="n">absolute</span></code>
is a flag indicating whether the weight decay coefficient is absolute. This applies when the decay is performed directly on the parameter. If this is false, the actual decay is <code class="highlight"><span></span><span class="n">weight_decay</span> <span class="o">*</span> <span class="n">learning_rate</span></code>
.</li></ul>
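<p>The three modes can be compared with a small sketch. <code class="highlight">apply_weight_decay</code> is a hypothetical free function written for this example, not the class defined on this page; it assumes that non-absolute decoupled decay scales the coefficient by the learning rate and that coupled decay adds the coefficient times the parameter to the gradient.</p>

```python
import torch


def apply_weight_decay(param, grad, lr, weight_decay,
                       weight_decouple=True, absolute=False):
    """Hypothetical helper; mirrors the flags of the weight decay class."""
    if weight_decouple:
        # Decay the parameter directly, bypassing the adaptive update
        factor = weight_decay if absolute else lr * weight_decay
        param.mul_(1.0 - factor)
        return grad
    # Classic L2 penalty: fold the decay into the gradient
    return grad + weight_decay * param


param = torch.tensor([1.0, -2.0])
grad = torch.tensor([0.5, 0.5])

# Coupled: the gradient changes, the parameter is left for the optimizer
coupled_grad = apply_weight_decay(param.clone(), grad, lr=0.1,
                                  weight_decay=0.01, weight_decouple=False)

# Decoupled, relative: the parameter shrinks by lr * weight_decay
decoupled_param = param.clone()
apply_weight_decay(decoupled_param, grad, lr=0.1, weight_decay=0.01)

# Decoupled, absolute: the parameter shrinks by weight_decay itself
absolute_param = param.clone()
apply_weight_decay(absolute_param, grad, lr=0.1, weight_decay=0.01,
                   absolute=True)
```

<p>With an adaptive optimizer the coupled form is rescaled by the second-moment estimate along with the rest of the gradient, while the decoupled forms shrink every weight by the same factor.</p>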
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">172</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">weight_decouple</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">absolute</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Check hyper-parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">185</span> <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">weight_decay</span><span class="p">:</span>
<span class="lineno">186</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid weight_decay value: </span><span class="si">{</span><span class="n">weight_decay</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="lineno">187</span>
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">absolute</span> <span class="o">=</span> <span class="n">absolute</span>
<span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decouple</span> <span class="o">=</span> <span class="n">weight_decouple</span>
<span class="lineno">190</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">=</span> <span class="n">weight_decay</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p> Return defaults for parameter groups</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">192</span> <span class="k">def</span> <span class="nf">defaults</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="k">return</span> <span class="nb">dict</span><span class="p">(</span><span class="n">weight_decay</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<h3>Perform weight decay and return the gradient</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">198</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">,</span> <span class="n">grad</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>If weight decay is decoupled, apply it directly to the parameter</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">204</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decouple</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>If the weight decay coefficient is absolute, shrink the parameter without scaling by the learning rate</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">206</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">absolute</span><span class="p">:</span>
<span class="lineno">207</span> <span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Otherwise, scale the weight decay coefficient by the learning rate</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">209</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">210</span> <span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">]</span> <span class="o">*</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Return the unmodified gradient. If weight decay is not decoupled, check for a non-zero coefficient instead</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">212</span> <span class="k">return</span> <span class="n">grad</span>
<span class="lineno">213</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">214</span> <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Add the weight decay term (L2 regularization) to the gradient and return the modified gradient; if the coefficient is zero, return the gradient unchanged</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">216</span> <span class="k">return</span> <span class="n">grad</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">])</span>
<span class="lineno">217</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">218</span> <span class="k">return</span> <span class="n">grad</span></pre></div>
</div>
</div>
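<div class='section'>
<div class='docs'>
<p>The branching above can be sketched on plain scalars to compare the three modes side by side. This is a minimal illustration, not the implementation documented on this page: the function name <code>apply_weight_decay</code> and its keyword arguments are hypothetical, it works on Python floats instead of tensors, and it returns the updated parameter alongside the gradient, whereas the code above modifies the parameter in place and returns only the gradient.</p>
</div>
<div class='code'>

```python
from typing import Dict, Tuple


def apply_weight_decay(param: float, grad: float, group: Dict[str, float],
                       weight_decouple: bool = True,
                       absolute: bool = False) -> Tuple[float, float]:
    """Scalar sketch (hypothetical helper); returns (new_param, grad_to_use)."""
    wd = group['weight_decay']
    if weight_decouple:
        # Decoupled decay shrinks the parameter itself and leaves the gradient alone
        if absolute:
            # Absolute coefficient: not scaled by the learning rate
            factor = 1.0 - wd
        else:
            # Otherwise the coefficient is scaled by the learning rate
            factor = 1.0 - group['lr'] * wd
        return param * factor, grad
    # L2 regularization: fold the decay term into the gradient
    if wd != 0:
        return param, grad + wd * param
    return param, grad


# Assumed example values, chosen only for illustration
group = {'lr': 0.1, 'weight_decay': 0.01}
decoupled = apply_weight_decay(2.0, 0.5, group)
absolute = apply_weight_decay(2.0, 0.5, group, absolute=True)
l2 = apply_weight_decay(2.0, 0.5, group, weight_decouple=False)
```

</div>
<div class='docs'>
<p>With these assumed values the decoupled modes leave the gradient at 0.5 and shrink the parameter by a factor of 0.999 (scaled by the learning rate) or 0.99 (absolute), while the L2 mode leaves the parameter untouched and adds 0.01 times the parameter to the gradient.</p>
</div>
</div>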
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>