mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 18:58:43 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			587 lines
		
	
	
		
			41 KiB
		
	
	
	
		
			HTML
		
	
	
	
	
	
			
		
		
	
	
			587 lines
		
	
	
		
			41 KiB
		
	
	
	
		
			HTML
		
	
	
	
	
	
| <!DOCTYPE html>
 | ||
| <html lang="en">
 | ||
| <head>
 | ||
|     <meta http-equiv="content-type" content="text/html;charset=utf-8"/>
 | ||
|     <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
 | ||
|     <meta name="description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad, Adam with warmup, Noam, RAdam and AdaBelief optimizers."/>
 | ||
| 
 | ||
|     <meta name="twitter:card" content="summary"/>
 | ||
|     <meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
 | ||
|     <meta name="twitter:title" content="Optimizers"/>
 | ||
|     <meta name="twitter:description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad, Adam with warmup, Noam, RAdam and AdaBelief optimizers."/>
 | ||
|     <meta name="twitter:site" content="@labmlai"/>
 | ||
|     <meta name="twitter:creator" content="@labmlai"/>
 | ||
| 
 | ||
|     <meta property="og:url" content="https://nn.labml.ai/optimizers/index.html"/>
 | ||
|     <meta property="og:title" content="Optimizers"/>
 | ||
|     <meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
 | ||
|     <meta property="og:site_name" content="Optimizers"/>
 | ||
|     <meta property="og:type" content="object"/>
 | ||
|     <meta property="og:title" content="Optimizers"/>
 | ||
|     <meta property="og:description" content="A set of PyTorch implementations/tutorials of popular gradient descent based optimizers. Currently includes Adam, AMSGrad, Adam with warmup, Noam, RAdam and AdaBelief optimizers."/>
 | ||
| 
 | ||
|     <title>Optimizers</title>
 | ||
|     <link rel="shortcut icon" href="/icon.png"/>
 | ||
|     <link rel="stylesheet" href="../pylit.css?v=1">
 | ||
|     <link rel="canonical" href="https://nn.labml.ai/optimizers/index.html"/>
 | ||
|     <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
 | ||
| 
 | ||
|     <!-- Global site tag (gtag.js) - Google Analytics -->
 | ||
|     <script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
 | ||
|     <script>
 | ||
|         window.dataLayer = window.dataLayer || [];  // reuse the global command queue if it already exists, otherwise start with an empty array
 | ||
| 
 | ||
|         function gtag() {  // queue one command by pushing the raw `arguments` object; presumably the form gtag.js expects, so do not convert it to an array (verify before changing)
 | ||
|             dataLayer.push(arguments);
 | ||
|         }
 | ||
| 
 | ||
|         gtag('js', new Date());  // 'js' command stamped with the current time (standard gtag.js initialisation)
 | ||
| 
 | ||
|         gtag('config', 'G-4V3HC8HBLH');  // measurement ID; keep in sync with the id= in the gtag/js loader URL above
 | ||
|     </script>
 | ||
| </head>
 | ||
| <body>
 | ||
| <div id='container'>
 | ||
|     <div id="background"></div>
 | ||
|     <div class='section'>
 | ||
|         <div class='docs'>
 | ||
|             <p>
 | ||
|                 <a class="parent" href="/">home</a>
 | ||
|                 <a class="parent" href="index.html">optimizers</a>
 | ||
|             </p>
 | ||
|             <p>
 | ||
|                 <a href="https://github.com/sponsors/labmlai" target="_blank">
 | ||
|                     <img alt="Sponsor"
 | ||
|                          src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86"
 | ||
|                          style="max-width:100%;"/></a>
 | ||
|                 <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
 | ||
|                     <img alt="Github"
 | ||
|                          src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
 | ||
|                          style="max-width:100%;"/></a>
 | ||
|                 <a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
 | ||
|                     <img alt="Twitter"
 | ||
|                          src="https://img.shields.io/twitter/follow/labmlai?style=social"
 | ||
|                          style="max-width:100%;"/></a>
 | ||
|             </p>
 | ||
|             <p>
 | ||
|                 <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/optimizers/__init__.py" target="_blank">
 | ||
|                     View code on Github</a>
 | ||
|             </p>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-0'>
 | ||
|         <div class='docs doc-strings'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-0'>#</a>
 | ||
|             </div>
 | ||
|             <h1>Optimizers</h1>
 | ||
| <h2>Optimizer Implementations</h2>
 | ||
| <ul><li><a href="adam.html">Adam Optimizer</a> </li>
 | ||
| <li><a href="amsgrad.html">AMSGrad Optimizer</a> </li>
 | ||
| <li><a href="adam_warmup.html">Adam Optimizer with warmup</a> </li>
 | ||
| <li><a href="noam.html">Noam Optimizer</a> </li>
 | ||
| <li><a href="radam.html">Rectified Adam Optimizer</a> </li>
 | ||
| <li><a href="ada_belief.html">AdaBelief Optimizer</a></li></ul>
 | ||
| <p>This <a href="mnist_experiment.html">MNIST example</a> uses these optimizers.</p>
 | ||
| <h2>Generic Adaptive Optimizer Base class and Weight Decay</h2>
 | ||
| <p>This file defines a common base class for <em>Adam</em> and extensions of it. The base class helps us implement other optimizers with minimal code, because the shared logic is reused.</p>
 | ||
| <p>We also define a special class for L2 weight decay, so that we don't have to implement it inside each of the optimizers, and can easily extend it to other kinds of weight decay, such as L1, without changing the optimizers.</p>
 | ||
| <p>Here are some concepts on PyTorch optimizers:</p>
 | ||
| <h3>Parameter groups</h3>
 | ||
| <p>PyTorch optimizers group parameters into sets called groups. Each group can have its own hyper-parameters like learning rates.</p>
 | ||
| <p>In the most common case there is only one group. This is what you get when you initialize your optimizer with:</p>
 | ||
| <pre  class="highlight lang-python"><code><span></span><span class="n">Optimizer</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span></code></pre>
 | ||
| <p>You can define multiple parameter groups when initializing the optimizer:</p>
 | ||
| <pre  class="highlight lang-python"><code><span></span><span class="n">Optimizer</span><span class="p">([{</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">model1</span><span class="o">.</span><span class="n">parameters</span><span class="p">()},</span> <span class="p">{</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">model2</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="s1">'lr'</span><span class="p">:</span> <span class="mi">2</span><span class="p">}])</span></code></pre>
 | ||
| <p>Here we pass a list of groups. Each group is a dictionary with its parameters under the key 'params'. You can specify any hyper-parameters as well. If a hyper-parameter is not defined for a group, it defaults to the optimizer-level default.</p>
 | ||
| <p>You can access (and even change) these groups, and their hyper-parameters with <code  class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span></code>
 | ||
| . Most learning rate scheduler implementations I've come across access this and change 'lr'.</p>
 | ||
| <h3>States</h3>
 | ||
| <p>The optimizer maintains a state (a dictionary) for each parameter (a tensor) in the dictionary <code  class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">state</span></code>
 | ||
| . This is where the optimizer maintains things like exponential averages.</p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">62</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Any</span>
 | ||
| <span class="lineno">63</span>
 | ||
| <span class="lineno">64</span><span class="kn">import</span> <span class="nn">torch</span>
 | ||
| <span class="lineno">65</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
 | ||
| <span class="lineno">66</span><span class="kn">from</span> <span class="nn">torch.optim.optimizer</span> <span class="kn">import</span> <span class="n">Optimizer</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-1'>
 | ||
|         <div class='docs doc-strings'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-1'>#</a>
 | ||
|             </div>
 | ||
|             <h2>Base class for <em>Adam</em> and extensions</h2>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">69</span><span class="k">class</span> <span class="nc">GenericAdaptiveOptimizer</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-2'>
 | ||
|         <div class='docs doc-strings'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-2'>#</a>
 | ||
|             </div>
 | ||
|             <h3>Initialize</h3>
 | ||
| <ul><li><code  class="highlight"><span></span><span class="n">params</span></code>
 | ||
|  is the collection of parameters or set of parameter groups. </li>
 | ||
| <li><code  class="highlight"><span></span><span class="n">defaults</span></code>
 | ||
|  is a dictionary of default hyper-parameters </li>
 | ||
| <li><code  class="highlight"><span></span><span class="n">lr</span></code>
 | ||
|  is the learning rate, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.0037em;">α</span></span></span></span></span> </li>
 | ||
| <li><code  class="highlight"><span></span><span class="n">betas</span></code>
 | ||
|  is the tuple <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> </li>
 | ||
| <li><code  class="highlight"><span></span><span class="n">eps</span></code>
 | ||
|  is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal">ϵ</span></span></span></span></span></li></ul>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">74</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">defaults</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">betas</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-3'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-3'>#</a>
 | ||
|             </div>
 | ||
|             <p>Check the hyper-parameters </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">86</span>        <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o"><=</span> <span class="n">lr</span><span class="p">:</span>
 | ||
| <span class="lineno">87</span>            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid learning rate: </span><span class="si">{</span><span class="n">lr</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
 | ||
| <span class="lineno">88</span>        <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o"><=</span> <span class="n">eps</span><span class="p">:</span>
 | ||
| <span class="lineno">89</span>            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid epsilon value: </span><span class="si">{</span><span class="n">eps</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
 | ||
| <span class="lineno">90</span>        <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o"><=</span> <span class="n">betas</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o"><</span> <span class="mf">1.0</span><span class="p">:</span>
 | ||
| <span class="lineno">91</span>            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid beta parameter at index 0: </span><span class="si">{</span><span class="n">betas</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
 | ||
| <span class="lineno">92</span>        <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o"><=</span> <span class="n">betas</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o"><</span> <span class="mf">1.0</span><span class="p">:</span>
 | ||
| <span class="lineno">93</span>            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid beta parameter at index 1: </span><span class="si">{</span><span class="n">betas</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-4'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-4'>#</a>
 | ||
|             </div>
 | ||
|             <p>Add the hyper-parameters to the defaults </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">96</span>        <span class="n">defaults</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">dict</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">))</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-5'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-5'>#</a>
 | ||
|             </div>
 | ||
|             <p>Initialize the PyTorch optimizer. This will create parameter groups with the default hyper-parameters </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">99</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">defaults</span><span class="p">)</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-6'>
 | ||
|         <div class='docs doc-strings'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-6'>#</a>
 | ||
|             </div>
 | ||
|             <h3>Initialize state for a given parameter tensor</h3>
 | ||
| <p>This should be overridden with code to initialize <code  class="highlight"><span></span><span class="n">state</span></code>
 | ||
|  for parameters <code  class="highlight"><span></span><span class="n">param</span></code>
 | ||
| . <code  class="highlight"><span></span><span class="n">group</span></code>
 | ||
|  is the parameter group dictionary to which <code  class="highlight"><span></span><span class="n">param</span></code>
 | ||
|  belongs.</p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">101</span>    <span class="k">def</span> <span class="nf">init_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">param</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">):</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-7'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-7'>#</a>
 | ||
|             </div>
 | ||
|             
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">108</span>        <span class="k">pass</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-8'>
 | ||
|         <div class='docs doc-strings'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-8'>#</a>
 | ||
|             </div>
 | ||
|             <h3>Take optimizer step on a parameter tensor</h3>
 | ||
| <p>This should be overridden and take the optimization step on <code  class="highlight"><span></span><span class="n">param</span></code>
 | ||
|  tensor <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span></span></span></span></span>, where <code  class="highlight"><span></span><span class="n">grad</span></code>
 | ||
|  is the gradient for that parameter, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>, <code  class="highlight"><span></span><span class="n">state</span></code>
 | ||
|  is the optimizer state dictionary for that parameter, and <code  class="highlight"><span></span><span class="n">group</span></code>
 | ||
|  is the parameter group dictionary <code  class="highlight"><span></span><span class="n">param</span></code>
 | ||
|  belongs to.</p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">110</span>    <span class="k">def</span> <span class="nf">step_param</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">grad</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-9'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-9'>#</a>
 | ||
|             </div>
 | ||
|             
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">119</span>        <span class="k">pass</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-10'>
 | ||
|         <div class='docs doc-strings'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-10'>#</a>
 | ||
|             </div>
 | ||
|             <h3>Optimizer step</h3>
 | ||
| <p>This is a template method that does the common work every <em>Adam</em>-based optimizer needs; subclasses only have to implement <code>init_state</code> and <code>step_param</code>.</p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">121</span>    <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
 | ||
| <span class="lineno">122</span>    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">closure</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-11'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-11'>#</a>
 | ||
|             </div>
 | ||
|             <p>Calculate loss.</p>
 | ||
| <p>🤔 I'm not sure when you need this. I guess it's if you define a function that calculates the loss, calls <code  class="highlight"><span></span><span class="n">loss</span><span class="o">.</span><span class="n">backward</span></code>
 | ||
|  and returns the loss; instead of calling it on your own, you could pass it to <code  class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span></code>
 | ||
| . (PyTorch's documentation notes that some optimizers, such as L-BFGS, need to re-evaluate the loss several times per step, which is why they require a closure.) 🤷‍♂️ </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">133</span>        <span class="n">loss</span> <span class="o">=</span> <span class="kc">None</span>
 | ||
| <span class="lineno">134</span>        <span class="k">if</span> <span class="n">closure</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
 | ||
| <span class="lineno">135</span>            <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">enable_grad</span><span class="p">():</span>
 | ||
| <span class="lineno">136</span>                <span class="n">loss</span> <span class="o">=</span> <span class="n">closure</span><span class="p">()</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-12'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-12'>#</a>
 | ||
|             </div>
 | ||
|             <p>Iterate through the parameter groups </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">139</span>        <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-13'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-13'>#</a>
 | ||
|             </div>
 | ||
|             <p>Iterate through the parameters in the parameter group </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">141</span>            <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]:</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-14'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-14'>#</a>
 | ||
|             </div>
 | ||
|             <p>Skip if the parameter has no gradient </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">143</span>                <span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
 | ||
| <span class="lineno">144</span>                    <span class="k">continue</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-15'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-15'>#</a>
 | ||
|             </div>
 | ||
|             <p>Get the gradient tensor </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">146</span>                <span class="n">grad</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-16'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-16'>#</a>
 | ||
|             </div>
 | ||
|             <p>We don't handle sparse gradients </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">148</span>                <span class="k">if</span> <span class="n">grad</span><span class="o">.</span><span class="n">is_sparse</span><span class="p">:</span>
 | ||
| <span class="lineno">149</span>                    <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">'GenericAdaptiveOptimizer does not support sparse gradients,'</span>
 | ||
| <span class="lineno">150</span>                                       <span class="s1">' please consider SparseAdam instead'</span><span class="p">)</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-17'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-17'>#</a>
 | ||
|             </div>
 | ||
|             <p>Get the state for the parameter </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">153</span>                <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="n">param</span><span class="p">]</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-18'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-18'>#</a>
 | ||
|             </div>
 | ||
|             <p>Initialize the state if state is uninitialized </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">156</span>                <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">state</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
 | ||
| <span class="lineno">157</span>                    <span class="bp">self</span><span class="o">.</span><span class="n">init_state</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">param</span><span class="p">)</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-19'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-19'>#</a>
 | ||
|             </div>
 | ||
|             <p>Take the optimization step on the parameter </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">160</span>                <span class="bp">self</span><span class="o">.</span><span class="n">step_param</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">grad</span><span class="p">,</span> <span class="n">param</span><span class="p">)</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-20'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-20'>#</a>
 | ||
|             </div>
 | ||
|             <p>Return the loss, calculated from closure </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">163</span>        <span class="k">return</span> <span class="n">loss</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-21'>
 | ||
|         <div class='docs doc-strings'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-21'>#</a>
 | ||
|             </div>
 | ||
|             <h2>L2 Weight decay</h2>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">166</span><span class="k">class</span> <span class="nc">WeightDecay</span><span class="p">:</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-22'>
 | ||
|         <div class='docs doc-strings'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-22'>#</a>
 | ||
|             </div>
 | ||
|             <h3>Initialize weight decay</h3>
 | ||
| <ul><li><code  class="highlight"><span></span><span class="n">weight_decay</span></code>
 | ||
|  is the decay coefficient </li>
 | ||
| <li><code  class="highlight"><span></span><span class="n">weight_decouple</span></code>
 | ||
|  is a flag indicating whether to add the weight decay to the gradient or directly decay from the parameter. If added to the gradient it will go through the normal optimizer update. </li>
 | ||
| <li><code  class="highlight"><span></span><span class="n">absolute</span></code>
 | ||
|  is a flag indicating whether the weight decay coefficient is absolute. This only applies when the decay is performed directly on the parameter (i.e. when <code  class="highlight"><span></span><span class="n">weight_decouple</span></code> is true). If it is false, the actual decay is <code  class="highlight"><span></span><span class="n">weight_decay</span></code>
 | ||
|  × <code  class="highlight"><span></span><span class="n">learning_rate</span></code>
 | ||
| .</li></ul>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">171</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">weight_decouple</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">absolute</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-23'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-23'>#</a>
 | ||
|             </div>
 | ||
|             <p>Check hyper-parameters </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">184</span>        <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">weight_decay</span><span class="p">:</span>
 | ||
| <span class="lineno">185</span>            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid weight_decay value: </span><span class="si">{</span><span class="n">weight_decay</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
 | ||
| <span class="lineno">186</span>
 | ||
| <span class="lineno">187</span>        <span class="bp">self</span><span class="o">.</span><span class="n">absolute</span> <span class="o">=</span> <span class="n">absolute</span>
 | ||
| <span class="lineno">188</span>        <span class="bp">self</span><span class="o">.</span><span class="n">weight_decouple</span> <span class="o">=</span> <span class="n">weight_decouple</span>
 | ||
| <span class="lineno">189</span>        <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">=</span> <span class="n">weight_decay</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-24'>
 | ||
|         <div class='docs doc-strings'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-24'>#</a>
 | ||
|             </div>
 | ||
|             <p> Return defaults for parameter groups</p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">191</span>    <span class="k">def</span> <span class="nf">defaults</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-25'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-25'>#</a>
 | ||
|             </div>
 | ||
|             
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">195</span>        <span class="k">return</span> <span class="nb">dict</span><span class="p">(</span><span class="n">weight_decay</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span><span class="p">)</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-26'>
 | ||
|         <div class='docs doc-strings'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-26'>#</a>
 | ||
|             </div>
 | ||
|             <h3>Perform weight decay and return the gradient</h3>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">197</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">,</span> <span class="n">grad</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">]):</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-27'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-27'>#</a>
 | ||
|             </div>
 | ||
|             <p>If we are doing the decay on the parameter directly </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">203</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decouple</span><span class="p">:</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-28'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-28'>#</a>
 | ||
|             </div>
 | ||
|             <p>If the weight decay coefficient is absolute </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">205</span>            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">absolute</span><span class="p">:</span>
 | ||
| <span class="lineno">206</span>                <span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">group</span><span class="p">[</span><span class="s1">'weight_decay'</span><span class="p">])</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-29'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-29'>#</a>
 | ||
|             </div>
 | ||
|             <p>Otherwise, scale the decay coefficient by the learning rate </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">208</span>            <span class="k">else</span><span class="p">:</span>
 | ||
| <span class="lineno">209</span>                <span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">group</span><span class="p">[</span><span class="s1">'lr'</span><span class="p">]</span> <span class="o">*</span> <span class="n">group</span><span class="p">[</span><span class="s1">'weight_decay'</span><span class="p">])</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-30'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-30'>#</a>
 | ||
|             </div>
 | ||
|             <p>Return the unmodified gradient </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">211</span>            <span class="k">return</span> <span class="n">grad</span>
 | ||
| <span class="lineno">212</span>        <span class="k">else</span><span class="p">:</span>
 | ||
| <span class="lineno">213</span>            <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">'weight_decay'</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='section' id='section-31'>
 | ||
|         <div class='docs'>
 | ||
|             <div class='section-link'>
 | ||
|                 <a href='#section-31'>#</a>
 | ||
|             </div>
 | ||
|             <p>Add the weight decay to the gradient and return the modified gradient </p>
 | ||
| 
 | ||
|         </div>
 | ||
|         <div class='code'>
 | ||
|             <div class="highlight"><pre><span class="lineno">215</span>                <span class="k">return</span> <span class="n">grad</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="n">group</span><span class="p">[</span><span class="s1">'weight_decay'</span><span class="p">])</span>
 | ||
| <span class="lineno">216</span>            <span class="k">else</span><span class="p">:</span>
 | ||
| <span class="lineno">217</span>                <span class="k">return</span> <span class="n">grad</span></pre></div>
 | ||
|         </div>
 | ||
|     </div>
 | ||
|     <div class='footer'>
 | ||
|         <a href="https://papers.labml.ai">Trending Research Papers</a>
 | ||
|         <a href="https://labml.ai">labml.ai</a>
 | ||
|     </div>
 | ||
| </div>
 | ||
| <script src="../interactive.js?v=1"></script>
 | ||
| <script>
 | ||
|     function handleImages() {
 | ||
|         var images = document.querySelectorAll('p>img')
 | ||
| 
 | ||
|         for (var i = 0; i < images.length; ++i) {
 | ||
|             handleImage(images[i])
 | ||
|         }
 | ||
|     }
 | ||
| 
 | ||
|     function handleImage(img) {
 | ||
|         img.parentElement.style.textAlign = 'center'
 | ||
| 
 | ||
|         var modal = document.createElement('div')
 | ||
|         modal.id = 'modal'
 | ||
| 
 | ||
|         var modalContent = document.createElement('div')
 | ||
|         modal.appendChild(modalContent)
 | ||
| 
 | ||
|         var modalImage = document.createElement('img')
 | ||
|         modalContent.appendChild(modalImage)
 | ||
| 
 | ||
|         var span = document.createElement('span')
 | ||
|         span.classList.add('close')
 | ||
|         span.textContent = 'x'
 | ||
|         modal.appendChild(span)
 | ||
| 
 | ||
|         img.onclick = function () {
 | ||
|             console.log('clicked')
 | ||
|             document.body.appendChild(modal)
 | ||
|             modalImage.src = img.src
 | ||
|         }
 | ||
| 
 | ||
|         span.onclick = function () {
 | ||
|             document.body.removeChild(modal)
 | ||
|         }
 | ||
|     }
 | ||
| 
 | ||
    // Run once: this inline script sits after the page content, so the paragraph images are already parsed
    handleImages()
 | ||
| </script>
 | ||
| </body>
 | ||
| </html> | 
