Files
Varuna Jayasiri 2038b11d29 ja translation
2023-05-10 17:00:29 -04:00

328 lines
42 KiB
HTML
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html lang="ja">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="これはNoamオプティマイザーのチュートリアル/実装です。Noam Optimizerにはウォームアップ期間があり、その後は学習率が指数関数的に低下します。"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="アテンション・イズ・オール・ユー・ニード論文の Noam オプティマイザー"/>
<meta name="twitter:description" content="これはNoamオプティマイザーのチュートリアル/実装です。Noam Optimizerにはウォームアップ期間があり、その後は学習率が指数関数的に低下します。"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/optimizers/noam.html"/>
<meta property="og:title" content="アテンション・イズ・オール・ユー・ニード論文の Noam オプティマイザー"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="アテンション・イズ・オール・ユー・ニード論文の Noam オプティマイザー"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="アテンション・イズ・オール・ユー・ニード論文の Noam オプティマイザー"/>
<meta property="og:description" content="これはNoamオプティマイザーのチュートリアル/実装です。Noam Optimizerにはウォームアップ期間があり、その後は学習率が指数関数的に低下します。"/>
<title>アテンション・イズ・オール・ユー・ニード論文の Noam オプティマイザー</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/optimizers/noam.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">optimizers</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/optimizers/noam.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>ノームオプティマイザー</h1>
<p>これは、論文「<a href="https://papers.labml.ai/paper/1706.03762">必要なのは注意だけ」<a href="https://pytorch.org">で紹介されているオプティマイザーのPyTorch実装です</a></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span>
<span class="lineno">15</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml_nn.optimizers</span> <span class="kn">import</span> <span class="n">WeightDecay</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml_nn.optimizers.amsgrad</span> <span class="kn">import</span> <span class="n">AMSGrad</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>ノームオプティマイザー</h2>
<p>このクラスは、で定義されている Adam オプティマイザを拡張したものです。<a href="adam.html"><code class="highlight"><span></span><span class="n">adam</span><span class="o">.</span><span class="n">py</span></code>
</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">20</span><span class="k">class</span> <span class="nc">Noam</span><span class="p">(</span><span class="n">AMSGrad</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<h3>オプティマイザを初期化</h3>
<ul><li><code class="highlight"><span></span><span class="n">params</span></code>
はパラメータのリストです</li>
<li><code class="highlight"><span></span><span class="n">lr</span></code>
は学習率 <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqg" style=""><span class="mord mathnormal" style="margin-right:0.0037em">α</span></span></span></span></span></span></li>
<li><code class="highlight"><span></span><span class="n">betas</span></code>
(,) <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> のタプルです <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></li>
<li><code class="highlight"><span></span><span class="n">eps</span></code>
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord coloredeq eqd" style=""><span class="mord mathnormal" style="">ϵ</span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.19444em;"><span class="mord">^</span></span></span></span></span></span></span></span></span></span></span><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqd" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span></span></span>またはそれに基づいている <code class="highlight"><span></span><span class="n">optimized_update</span></code>
</li>
<li><code class="highlight"><span></span><span class="n">weight_decay</span></code>
<code class="highlight"><span></span><span class="n">WeightDecay</span></code>
で定義されているクラスのインスタンスです <a href="index.html"><code class="highlight"><span></span><span class="fm">__init__</span><span class="o">.</span><span class="n">py</span></code>
</a></li>
<li>'optimized_update'は追加後に行うことでセカンドモーメントのバイアス補正を最適化するかどうかのフラグです <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqd" style=""><span class="mord mathnormal" style="">ϵ</span></span></span></span></span></span></li>
<li><code class="highlight"><span></span><span class="n">amsgrad</span></code>
amsGradを使用するか、プレーンなAdamにフォールバックするかを示すフラグです</li>
<li><code class="highlight"><span></span><span class="n">warmup</span></code>
ウォームアップステップ数</li>
<li><code class="highlight"><span></span><span class="n">d_model</span></code>
モデルサイズ、つまり変圧器の次元数</li>
<li><code class="highlight"><span></span><span class="n">defaults</span></code>
グループ値のデフォルト辞書です。これは、クラスを拡張する場合に便利です<code class="highlight"><span></span><span class="n">AdamWarmup</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">),</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-16</span><span class="p">,</span>
<span class="lineno">28</span> <span class="n">weight_decay</span><span class="p">:</span> <span class="n">WeightDecay</span> <span class="o">=</span> <span class="n">WeightDecay</span><span class="p">(),</span>
<span class="lineno">29</span> <span class="n">optimized_update</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="lineno">30</span> <span class="n">amsgrad</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="lineno">31</span> <span class="n">warmup</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">d_model</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">defaults</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">defaults</span> <span class="o">=</span> <span class="p">{}</span> <span class="k">if</span> <span class="n">defaults</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">defaults</span>
<span class="lineno">50</span> <span class="n">defaults</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">dict</span><span class="p">(</span><span class="n">warmup</span><span class="o">=</span><span class="n">warmup</span><span class="p">))</span>
<span class="lineno">51</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">optimized_update</span><span class="p">,</span> <span class="n">amsgrad</span><span class="p">,</span> <span class="n">defaults</span><span class="p">)</span>
<span class="lineno">52</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">d_model</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<h3>学習率を取得</h3>
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:2.40003em;vertical-align:-0.95003em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqg" style="margin-right:0.0037em">α</span></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.25278em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord sqrt" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.85722em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em"><span class="mord" style=""><span class="mord mathnormal" style="">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">m</span><span class="mord mathnormal mtight" style="">o</span><span class="mord mathnormal mtight" style="">d</span><span class="mord mathnormal mtight" style="">e</span><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.81722em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.18278000000000005em;"><span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord" style=""><span class="mop coloredeq eqb" style=""><span style="">m</span><span style="">i</span><span style="">n</span></span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqb" style=""><span class="delimsizing size3" style=""><span style="">(</span></span></span><span class="mord coloredeq eqb" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.21746em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord sqrt" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.89254em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em"><span class="mord mathnormal" style="">t</span></span></span><span style="top:-2.85254em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.14746000000000004em;"><span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mpunct coloredeq eqb" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqb" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.29208em;"><span style="top:-2.2960000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqh" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.814em;"><span style="top:-2.989em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">3/2</span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.704em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord coloredeq eqb" style=""><span class="delimsizing size3" style=""><span style="">)</span></span></span></span></span></span></span></span></span></span><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqh" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span>はウォームアップステップの数です。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="k">def</span> <span class="nf">get_lr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:2.40003em;vertical-align:-0.95003em;"></span><span class="mord coloredeq eqb" style=""><span class="mop" style=""><span style="">m</span><span style="">i</span><span style="">n</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="delimsizing size3" style=""><span style="">(</span></span></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.21746em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord sqrt" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.89254em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em"><span class="mord mathnormal" style="">t</span></span></span><span style="top:-2.85254em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.14746000000000004em;"><span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.29208em;"><span style="top:-2.2960000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqh" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.814em;"><span style="top:-2.989em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">3/2</span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.704em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord" style=""><span class="delimsizing size3" style=""><span style="">)</span></span></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="n">factor</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">state</span><span class="p">[</span><span class="s1">&#39;step&#39;</span><span class="p">]</span> <span class="o">**</span> <span class="p">(</span><span class="o">-</span><span class="mf">0.5</span><span class="p">),</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;step&#39;</span><span class="p">]</span> <span class="o">*</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;warmup&#39;</span><span class="p">]</span> <span class="o">**</span> <span class="p">(</span><span class="o">-</span><span class="mf">1.5</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:2.40003em;vertical-align:-0.95003em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqg" style="margin-right:0.0037em">α</span></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.25278em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord sqrt" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.85722em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em"><span class="mord" style=""><span class="mord mathnormal" style="">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">m</span><span class="mord mathnormal mtight" style="">o</span><span class="mord mathnormal mtight" style="">d</span><span class="mord mathnormal mtight" style="">e</span><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.81722em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.18278000000000005em;"><span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord" style=""><span class="mop coloredeq eqb" style=""><span style="">m</span><span style="">i</span><span style="">n</span></span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqb" style=""><span class="delimsizing size3" style=""><span style="">(</span></span></span><span class="mord coloredeq eqb" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.21746em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord sqrt" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.89254em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em"><span class="mord mathnormal" style="">t</span></span></span><span style="top:-2.85254em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.14746000000000004em;"><span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mpunct coloredeq eqb" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mord coloredeq eqb" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.29208em;"><span style="top:-2.2960000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqh" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.814em;"><span style="top:-2.989em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">3/2</span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.704em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord coloredeq eqb" style=""><span class="delimsizing size3" style=""><span style="">)</span></span></span></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="k">return</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_model</span> <span class="o">**</span> <span class="p">(</span><span class="o">-</span><span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="n">factor</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<h3>さまざまなウォームアップとモデルサイズの学習率をプロット</h3>
<p><img alt="Plot of learning rate" src="noam_lr.png"></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span><span class="k">def</span> <span class="nf">_test_noam_lr</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
<span class="lineno">74</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="lineno">75</span> <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">76</span>
<span class="lineno">77</span> <span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
<span class="lineno">78</span> <span class="n">opts</span> <span class="o">=</span> <span class="p">[</span><span class="n">Noam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">d_model</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="mi">4000</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
<span class="lineno">79</span> <span class="n">Noam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">d_model</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="mi">8000</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
<span class="lineno">80</span> <span class="n">Noam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">d_model</span><span class="o">=</span><span class="mi">2048</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="mi">2000</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mi">1</span><span class="p">)]</span>
<span class="lineno">81</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">20000</span><span class="p">),</span> <span class="p">[[</span><span class="n">opt</span><span class="o">.</span><span class="n">get_lr</span><span class="p">({</span><span class="s1">&#39;step&#39;</span><span class="p">:</span> <span class="n">i</span><span class="p">},</span> <span class="n">opt</span><span class="o">.</span><span class="n">defaults</span><span class="p">)</span> <span class="k">for</span> <span class="n">opt</span> <span class="ow">in</span> <span class="n">opts</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">20000</span><span class="p">)])</span>
<span class="lineno">82</span> <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">([</span><span class="s2">&quot;512:4000&quot;</span><span class="p">,</span> <span class="s2">&quot;512:8000&quot;</span><span class="p">,</span> <span class="s2">&quot;2048:2000&quot;</span><span class="p">])</span>
<span class="lineno">83</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Learning Rate&quot;</span><span class="p">)</span>
<span class="lineno">84</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
<span class="lineno">85</span>
<span class="lineno">86</span>
<span class="lineno">87</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">88</span> <span class="n">_test_noam_lr</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src=../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>