Files
2024-06-21 19:35:22 +05:30

452 lines
37 KiB
HTML
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html lang="zh">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="这为优化器实现了一个可配置的模块。"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="可配置的优化器模块"/>
<meta name="twitter:description" content="这为优化器实现了一个可配置的模块。"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/optimizers/configs.html"/>
<meta property="og:title" content="可配置的优化器模块"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="可配置的优化器模块"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="可配置的优化器模块"/>
<meta property="og:description" content="这为优化器实现了一个可配置的模块。"/>
<title>可配置的优化器模块</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/optimizers/configs.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
// Google Analytics (gtag.js) bootstrap: reuse the global command queue if the
// async loader script above already created it, otherwise start an empty one.
window.dataLayer = window.dataLayer || [];
// Queue a gtag command. NOTE: the `arguments` object itself is pushed (not a
// copied array); keep this form as-is, it mirrors Google's standard snippet and
// gtag.js may treat it differently from a plain array.
function gtag() {
dataLayer.push(arguments);
}
// Queue the 'js' command with the current time, then configure the measurement ID.
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">optimizers</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/optimizers/configs.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>可配置的优化器</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">10</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span>
<span class="lineno">11</span>
<span class="lineno">12</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">13</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">BaseConfigs</span><span class="p">,</span> <span class="n">option</span><span class="p">,</span> <span class="n">meta_config</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml_nn.optimizers</span> <span class="kn">import</span> <span class="n">WeightDecay</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p><a id="OptimizerConfigs"></a></p>
<h2>优化器配置</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">18</span><span class="k">class</span> <span class="nc">OptimizerConfigs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>优化器</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">26</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>权重衰减(<code>WeightDecay</code> 对象)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">29</span> <span class="n">weight_decay_obj</span><span class="p">:</span> <span class="n">WeightDecay</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>权重衰减是否解耦(即不将权重衰减加到梯度中)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span> <span class="n">weight_decouple</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>权重衰减</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">34</span> <span class="n">weight_decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>权重衰减是绝对值,还是应乘以学习率</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">36</span> <span class="n">weight_decay_absolute</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>是否使用优化后的 Adam 更新(epsilon 的处理方式不同)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">39</span> <span class="n">optimized_adam_update</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>要优化的参数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="n">parameters</span><span class="p">:</span> <span class="nb">any</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>学习率<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.0037em;">α</span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span> <span class="n">learning_rate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.01</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Adam<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> 的 Beta 值</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="n">betas</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Adam 的 Epsilon <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal">ϵ</span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-08</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>SGD 的动量</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">52</span> <span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>是否使用 AmsGrad</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">amsgrad</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>预热阶段的优化器步数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="n">warmup</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2_000</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>优化器总步数(用于余弦衰减)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">total_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">1e10</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>是否在 AdaBelief 中退化为 SGD</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="n">degenerate_to_sgd</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>是否在 AdaBelief 中使用修正的 Adam(Rectified Adam)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">rectify</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Noam 优化器的模型嵌入大小</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span>
<span class="lineno">69</span>
<span class="lineno">70</span> <span class="n">rho</span><span class="p">:</span> <span class="nb">float</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">73</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">_primary</span><span class="o">=</span><span class="s1">&#39;optimizer&#39;</span><span class="p">)</span>
<span class="lineno">74</span>
<span class="lineno">75</span>
<span class="lineno">76</span><span class="n">meta_config</span><span class="p">(</span><span class="n">OptimizerConfigs</span><span class="o">.</span><span class="n">parameters</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span><span class="nd">@option</span><span class="p">(</span><span class="n">OptimizerConfigs</span><span class="o">.</span><span class="n">weight_decay_obj</span><span class="p">,</span> <span class="s1">&#39;L2&#39;</span><span class="p">)</span>
<span class="lineno">80</span><span class="k">def</span> <span class="nf">_weight_decay</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">OptimizerConfigs</span><span class="p">):</span>
<span class="lineno">81</span> <span class="k">return</span> <span class="n">WeightDecay</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">weight_decay</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">weight_decouple</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">weight_decay_absolute</span><span class="p">)</span>
<span class="lineno">82</span>
<span class="lineno">83</span>
<span class="lineno">84</span><span class="nd">@option</span><span class="p">(</span><span class="n">OptimizerConfigs</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="s1">&#39;SGD&#39;</span><span class="p">)</span>
<span class="lineno">85</span><span class="k">def</span> <span class="nf">_sgd_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">OptimizerConfigs</span><span class="p">):</span>
<span class="lineno">86</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">parameters</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">momentum</span><span class="p">,</span>
<span class="lineno">87</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">weight_decay</span><span class="p">)</span>
<span class="lineno">88</span>
<span class="lineno">89</span>
<span class="lineno">90</span><span class="nd">@option</span><span class="p">(</span><span class="n">OptimizerConfigs</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="s1">&#39;Adam&#39;</span><span class="p">)</span>
<span class="lineno">91</span><span class="k">def</span> <span class="nf">_adam_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">OptimizerConfigs</span><span class="p">):</span>
<span class="lineno">92</span> <span class="k">if</span> <span class="n">c</span><span class="o">.</span><span class="n">amsgrad</span><span class="p">:</span>
<span class="lineno">93</span> <span class="kn">from</span> <span class="nn">labml_nn.optimizers.amsgrad</span> <span class="kn">import</span> <span class="n">AMSGrad</span>
<span class="lineno">94</span> <span class="k">return</span> <span class="n">AMSGrad</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">parameters</span><span class="p">,</span>
<span class="lineno">95</span> <span class="n">lr</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span>
<span class="lineno">96</span> <span class="n">optimized_update</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">optimized_adam_update</span><span class="p">,</span>
<span class="lineno">97</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">weight_decay_obj</span><span class="p">,</span> <span class="n">amsgrad</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">amsgrad</span><span class="p">)</span>
<span class="lineno">98</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">99</span> <span class="kn">from</span> <span class="nn">labml_nn.optimizers.adam</span> <span class="kn">import</span> <span class="n">Adam</span>
<span class="lineno">100</span> <span class="k">return</span> <span class="n">Adam</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">parameters</span><span class="p">,</span>
<span class="lineno">101</span> <span class="n">lr</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span>
<span class="lineno">102</span> <span class="n">optimized_update</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">optimized_adam_update</span><span class="p">,</span>
<span class="lineno">103</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">weight_decay_obj</span><span class="p">)</span>
<span class="lineno">104</span>
<span class="lineno">105</span>
<span class="lineno">106</span><span class="nd">@option</span><span class="p">(</span><span class="n">OptimizerConfigs</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="s1">&#39;AdamW&#39;</span><span class="p">)</span>
<span class="lineno">107</span><span class="k">def</span> <span class="nf">_adam_warmup_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">OptimizerConfigs</span><span class="p">):</span>
<span class="lineno">108</span> <span class="kn">from</span> <span class="nn">labml_nn.optimizers.adam_warmup</span> <span class="kn">import</span> <span class="n">AdamWarmup</span>
<span class="lineno">109</span> <span class="k">return</span> <span class="n">AdamWarmup</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">parameters</span><span class="p">,</span>
<span class="lineno">110</span> <span class="n">lr</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span>
<span class="lineno">111</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">weight_decay_obj</span><span class="p">,</span> <span class="n">amsgrad</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">amsgrad</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">warmup</span><span class="p">)</span>
<span class="lineno">112</span>
<span class="lineno">113</span>
<span class="lineno">114</span><span class="nd">@option</span><span class="p">(</span><span class="n">OptimizerConfigs</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="s1">&#39;RAdam&#39;</span><span class="p">)</span>
<span class="lineno">115</span><span class="k">def</span> <span class="nf">_radam_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">OptimizerConfigs</span><span class="p">):</span>
<span class="lineno">116</span> <span class="kn">from</span> <span class="nn">labml_nn.optimizers.radam</span> <span class="kn">import</span> <span class="n">RAdam</span>
<span class="lineno">117</span> <span class="k">return</span> <span class="n">RAdam</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">parameters</span><span class="p">,</span>
<span class="lineno">118</span> <span class="n">lr</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span>
<span class="lineno">119</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">weight_decay_obj</span><span class="p">,</span> <span class="n">amsgrad</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">amsgrad</span><span class="p">,</span>
<span class="lineno">120</span> <span class="n">degenerated_to_sgd</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">degenerate_to_sgd</span><span class="p">)</span>
<span class="lineno">121</span>
<span class="lineno">122</span>
<span class="lineno">123</span><span class="nd">@option</span><span class="p">(</span><span class="n">OptimizerConfigs</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="s1">&#39;AdaBelief&#39;</span><span class="p">)</span>
<span class="lineno">124</span><span class="k">def</span> <span class="nf">_ada_belief_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">OptimizerConfigs</span><span class="p">):</span>
<span class="lineno">125</span> <span class="kn">from</span> <span class="nn">labml_nn.optimizers.ada_belief</span> <span class="kn">import</span> <span class="n">AdaBelief</span>
<span class="lineno">126</span> <span class="k">return</span> <span class="n">AdaBelief</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">parameters</span><span class="p">,</span>
<span class="lineno">127</span> <span class="n">lr</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span>
<span class="lineno">128</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">weight_decay_obj</span><span class="p">,</span> <span class="n">amsgrad</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">amsgrad</span><span class="p">,</span>
<span class="lineno">129</span> <span class="n">degenerate_to_sgd</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">degenerate_to_sgd</span><span class="p">,</span>
<span class="lineno">130</span> <span class="n">rectify</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">rectify</span><span class="p">)</span>
<span class="lineno">131</span>
<span class="lineno">132</span>
<span class="lineno">133</span><span class="nd">@option</span><span class="p">(</span><span class="n">OptimizerConfigs</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="s1">&#39;Noam&#39;</span><span class="p">)</span>
<span class="lineno">134</span><span class="k">def</span> <span class="nf">_noam_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">OptimizerConfigs</span><span class="p">):</span>
<span class="lineno">135</span> <span class="kn">from</span> <span class="nn">labml_nn.optimizers.noam</span> <span class="kn">import</span> <span class="n">Noam</span>
<span class="lineno">136</span> <span class="k">return</span> <span class="n">Noam</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">parameters</span><span class="p">,</span>
<span class="lineno">137</span> <span class="n">lr</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span>
<span class="lineno">138</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">weight_decay_obj</span><span class="p">,</span> <span class="n">amsgrad</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">amsgrad</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">warmup</span><span class="p">,</span>
<span class="lineno">139</span> <span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
<span class="lineno">140</span>
<span class="lineno">141</span>
<span class="lineno">142</span><span class="nd">@option</span><span class="p">(</span><span class="n">OptimizerConfigs</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="s1">&#39;Sophia&#39;</span><span class="p">)</span>
<span class="lineno">143</span><span class="k">def</span> <span class="nf">_sophia_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">OptimizerConfigs</span><span class="p">):</span>
<span class="lineno">144</span> <span class="kn">from</span> <span class="nn">labml_nn.optimizers.sophia</span> <span class="kn">import</span> <span class="n">Sophia</span>
<span class="lineno">145</span> <span class="k">return</span> <span class="n">Sophia</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">parameters</span><span class="p">,</span>
<span class="lineno">146</span> <span class="n">lr</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span>
<span class="lineno">147</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">weight_decay_obj</span><span class="p">,</span> <span class="n">rho</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">rho</span><span class="p">)</span>
<span class="lineno">148</span>
<span class="lineno">149</span>
<span class="lineno">150</span><span class="nd">@option</span><span class="p">(</span><span class="n">OptimizerConfigs</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="s1">&#39;AdamWarmupCosineDecay&#39;</span><span class="p">)</span>
<span class="lineno">151</span><span class="k">def</span> <span class="nf">_noam_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">OptimizerConfigs</span><span class="p">):</span>
<span class="lineno">152</span> <span class="kn">from</span> <span class="nn">labml_nn.optimizers.adam_warmup_cosine_decay</span> <span class="kn">import</span> <span class="n">AdamWarmupCosineDecay</span>
<span class="lineno">153</span> <span class="k">return</span> <span class="n">AdamWarmupCosineDecay</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">parameters</span><span class="p">,</span>
<span class="lineno">154</span> <span class="n">lr</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span>
<span class="lineno">155</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">weight_decay_obj</span><span class="p">,</span> <span class="n">amsgrad</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">amsgrad</span><span class="p">,</span>
<span class="lineno">156</span> <span class="n">warmup</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">warmup</span><span class="p">,</span> <span class="n">total_steps</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">total_steps</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
// Make every image that sits directly inside a paragraph (`p>img`) clickable:
// clicking opens an enlarged copy in an overlay; the overlay's "x" or the
// Escape key closes it again.
function handleImages() {
  var images = document.querySelectorAll('p>img')
  for (var i = 0; i < images.length; ++i) {
    handleImage(images[i])
  }
}

// Centre `img` within its parent element and attach the open/close overlay
// behaviour to it. The overlay is built once per image but is only inserted
// into the document while it is open.
function handleImage(img) {
  img.parentElement.style.textAlign = 'center'

  var modal = document.createElement('div')
  // Styling hook; presumably matched by a #modal rule in the page stylesheet
  // (not visible here), so the id is kept unchanged.
  modal.id = 'modal'
  var modalContent = document.createElement('div')
  modal.appendChild(modalContent)
  var modalImage = document.createElement('img')
  // Carry the source image's text alternative over to the enlarged copy.
  modalImage.alt = img.alt
  modalContent.appendChild(modalImage)

  var closeControl = document.createElement('span')
  closeControl.classList.add('close')
  closeControl.textContent = 'x'
  modal.appendChild(closeControl)

  // Detach the overlay (if attached) and stop listening for Escape.
  function closeModal() {
    // Guard: removeChild throws if the overlay is no longer attached.
    if (modal.parentNode) {
      modal.parentNode.removeChild(modal)
    }
    document.removeEventListener('keydown', onKeyDown)
  }

  // Give keyboard users a way to dismiss the overlay.
  function onKeyDown(event) {
    if (event.key === 'Escape') {
      closeModal()
    }
  }

  img.onclick = function () {
    document.body.appendChild(modal)
    modalImage.src = img.src
    // Adding the same listener twice is a no-op, so repeated clicks are safe.
    document.addEventListener('keydown', onKeyDown)
  }
  closeControl.onclick = closeModal
}

handleImages()
</script>
</body>
</html>