Files
2024-06-21 19:35:22 +05:30

501 lines
41 KiB
HTML
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html lang="zh">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="Adam 优化器的一个简单的 PyTorch 实现/教程"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="半精度训练的 Adam Optimizer"/>
<meta name="twitter:description" content="Adam 优化器的一个简单的 PyTorch 实现/教程"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/optimizers/adam_fp16.html"/>
<meta property="og:title" content="半精度训练的 Adam Optimizer"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="半精度训练的 Adam Optimizer"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="半精度训练的 Adam Optimizer"/>
<meta property="og:description" content="Adam 优化器的一个简单的 PyTorch 实现/教程"/>
<title>半精度训练的 Adam Optimizer</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/optimizers/adam_fp16.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">optimizers</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/optimizers/adam_fp16.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>半精度训练的 Adam Optimizer</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">10</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Any</span>
<span class="lineno">11</span>
<span class="lineno">12</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">13</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">torch.optim</span> <span class="kn">import</span> <span class="n">Optimizer</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">torch.cuda.amp</span> <span class="kn">import</span> <span class="n">grad_scaler</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">defaultdict</span><span class="p">,</span> <span class="n">abc</span>
<span class="lineno">17</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_nn.optimizers</span> <span class="kn">import</span> <span class="n">WeightDecay</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_nn.optimizers.adam</span> <span class="kn">import</span> <span class="n">Adam</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>半精度训练的 Adam Optimizer</h2>
<p>我们扩展了 <a href="adam.html">Adam Optimizer</a>,但使用 FP32 来存储渐变和时刻。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">22</span><span class="k">class</span> <span class="nc">AdamFP16</span><span class="p">(</span><span class="n">Adam</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">29</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-3</span><span class="p">,</span> <span class="n">betas</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">),</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-16</span><span class="p">,</span>
<span class="lineno">30</span> <span class="n">weight_decay</span><span class="p">:</span> <span class="n">WeightDecay</span> <span class="o">=</span> <span class="n">WeightDecay</span><span class="p">(),</span> <span class="n">optimized_update</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="lineno">31</span> <span class="n">defaults</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>用于存储 32 位渐变的参数。这由下面<code class="highlight"><span></span><span class="n">GradScaler</span></code>
定义的填充。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_fp32</span> <span class="o">=</span> <span class="p">{}</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>调用 <a href="adam.html">Adam 优化器</a>初始化器</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">35</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">optimized_update</span><span class="p">,</span> <span class="n">defaults</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<h3>初始化参数状态</h3>
<ul><li><code class="highlight"><span></span><span class="n">state</span></code>
是参数(张量)的优化器状态</li>
<li><code class="highlight"><span></span><span class="n">group</span></code>
存储参数组的优化程序属性</li>
<li><code class="highlight"><span></span><span class="n">param</span></code>
是参数张量<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.902771em;vertical-align:-0.208331em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqe" style="">t</span></span><span class="mbin mtight" style=""></span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span></span></span></span></span></span></li></ul>
<p>所有状态张量都使用 FP32。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span> <span class="k">def</span> <span class="nf">init_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">param</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>这是优化器对参数采取的步骤数,<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqe" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;step&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>梯度的指数移动平均线,<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqe" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;exp_avg&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">memory_format</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">preserve_format</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>梯度平方值的指数移动平均线,<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqe" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;exp_avg_sq&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">memory_format</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">preserve_format</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>维护参数的 FP32 副本</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;fp32_copy&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<h3>对给定参数张量执行更新步骤</h3>
<ul><li><code class="highlight"><span></span><span class="n">state</span></code>
是参数(张量)的优化器状态</li>
<li><code class="highlight"><span></span><span class="n">group</span></code>
存储参数组的优化程序属性</li>
<li><code class="highlight"><span></span><span class="n">grad</span></code>
是参数的当前梯<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqe" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>度张量<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.902771em;vertical-align:-0.208331em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqe" style="">t</span></span><span class="mbin mtight" style=""></span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span></span></span></span></span></span></li>
<li><code class="highlight"><span></span><span class="n">param</span></code>
是参数张量<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.902771em;vertical-align:-0.208331em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqe" style="">t</span></span><span class="mbin mtight" style=""></span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span></span></span></span></span></span></li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="k">def</span> <span class="nf">step_param</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">grad</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>获取 FP32 参数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">param_fp32</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;fp32_copy&#39;</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>获取 FP32 渐变(如果有)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">grad_fp32</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_fp32</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="lineno">71</span> <span class="k">if</span> <span class="n">grad_fp32</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">72</span> <span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_fp32</span><span class="p">[</span><span class="n">param</span><span class="p">]</span>
<span class="lineno">73</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">grad_fp32</span>
<span class="lineno">74</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>否则,将渐变转换为 FP32</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">grad</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>计算体重衰减</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="n">grad</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span><span class="p">(</span><span class="n">param_fp32</span><span class="p">,</span> <span class="n">grad</span><span class="p">,</span> <span class="n">group</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>获取<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqe" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqe" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="n">m</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_mv</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">grad</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqe" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span>增加优化器步数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;step&#39;</span><span class="p">]</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>执行 <em>Adam</em> 更新</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="bp">self</span><span class="o">.</span><span class="n">adam_update</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">param_fp32</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>设置参数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">param_fp32</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<h2>具有半精度渐变的渐变缩放器</h2>
<p>我们将 PyTorch 梯度缩放器扩展为使用 FP32 渐变。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">94</span><span class="k">class</span> <span class="nc">GradScalerFP16</span><span class="p">(</span><span class="n">grad_scaler</span><span class="o">.</span><span class="n">GradScaler</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="k">def</span> <span class="nf">_unscale_grads_</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">Optimizer</span><span class="p">,</span> <span class="n">inv_scale</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">found_inf</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">102</span> <span class="n">allow_fp16</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span>
<span class="lineno">103</span> <span class="n">per_device_inv_scale</span> <span class="o">=</span> <span class="n">grad_scaler</span><span class="o">.</span><span class="n">_MultiDeviceReplicator</span><span class="p">(</span><span class="n">inv_scale</span><span class="p">)</span>
<span class="lineno">104</span> <span class="n">per_device_found_inf</span> <span class="o">=</span> <span class="n">grad_scaler</span><span class="o">.</span><span class="n">_MultiDeviceReplicator</span><span class="p">(</span><span class="n">found_inf</span><span class="p">)</span>
<span class="lineno">105</span>
<span class="lineno">106</span> <span class="n">per_device_and_dtype_grads</span> <span class="o">=</span> <span class="n">defaultdict</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">defaultdict</span><span class="p">(</span><span class="nb">list</span><span class="p">))</span> <span class="c1"># type: ignore[var-annotated]</span>
<span class="lineno">107</span>
<span class="lineno">108</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>循环浏览参数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
<span class="lineno">111</span> <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s2">&quot;params&quot;</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>跳过不可训练的参数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span> <span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">114</span> <span class="k">continue</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>未针对稀疏张量实现</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">is_sparse</span><span class="p">:</span>
<span class="lineno">117</span> <span class="k">raise</span> <span class="ne">NotImplementedError</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>如果我们使用设置为<code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">grad_fp32</span><span class="p">[</span><span class="n">param</span><span class="p">]</span></code>
FP32 渐变的<code class="highlight"><span></span><span class="n">AdamFP16</span></code>
优化器</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">120</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">AdamFP16</span><span class="p">):</span>
<span class="lineno">121</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
<span class="lineno">122</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">grad_fp32</span><span class="p">[</span><span class="n">param</span><span class="p">]</span> <span class="o">=</span> <span class="n">grad</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>否则,不要将渐变转换为 FP32</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">125</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span>
<span class="lineno">126</span>
<span class="lineno">127</span> <span class="n">per_device_and_dtype_grads</span><span class="p">[</span><span class="n">grad</span><span class="o">.</span><span class="n">device</span><span class="p">][</span><span class="n">grad</span><span class="o">.</span><span class="n">dtype</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">grad</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>取消缩放所有渐变</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">for</span> <span class="n">device</span><span class="p">,</span> <span class="n">per_dtype_grads</span> <span class="ow">in</span> <span class="n">per_device_and_dtype_grads</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="lineno">131</span> <span class="k">for</span> <span class="n">grads</span> <span class="ow">in</span> <span class="n">per_dtype_grads</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
<span class="lineno">132</span> <span class="n">torch</span><span class="o">.</span><span class="n">_amp_foreach_non_finite_check_and_unscale_</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span>
<span class="lineno">133</span> <span class="n">per_device_found_inf</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="lineno">134</span> <span class="n">per_device_inv_scale</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">device</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="k">return</span> <span class="n">per_device_found_inf</span><span class="o">.</span><span class="n">_per_device_tensors</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src=../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>