<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A simple PyTorch implementation/tutorial of Adam optimizer"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Adam Optimizer for Half Precision Training"/>
<meta name="twitter:description" content="A simple PyTorch implementation/tutorial of Adam optimizer"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/optimizers/adam_fp16.html"/>
<meta property="og:title" content="Adam Optimizer for Half Precision Training"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Adam Optimizer for Half Precision Training"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Adam Optimizer for Half Precision Training"/>
<meta property="og:description" content="A simple PyTorch implementation/tutorial of Adam optimizer"/>
<title>Adam Optimizer for Half Precision Training</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/optimizers/adam_fp16.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">optimizers</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/optimizers/adam_fp16.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Adam Optimizer for Half Precision Training</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">10</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Any</span>
<span class="lineno">11</span>
<span class="lineno">12</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">13</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">torch.optim</span> <span class="kn">import</span> <span class="n">Optimizer</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">torch.cuda.amp</span> <span class="kn">import</span> <span class="n">grad_scaler</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">defaultdict</span><span class="p">,</span> <span class="n">abc</span>
<span class="lineno">17</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_nn.optimizers</span> <span class="kn">import</span> <span class="n">WeightDecay</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_nn.optimizers.adam</span> <span class="kn">import</span> <span class="n">Adam</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Adam Optimizer for Half Precision Training</h2>
<p>We extend the <a href="adam.html">Adam Optimizer</a> but store the gradients, the moments, and a master copy of the parameters in FP32, while the model parameters themselves stay in FP16.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">22</span><span class="k">class</span> <span class="nc">AdamFP16</span><span class="p">(</span><span class="n">Adam</span><span class="p">):</span></pre></div>
</div>
</div>
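<div class='section'>
<div class='docs'>
<p>A minimal usage sketch, assuming a CUDA device is available: the model, data, and hyper-parameters below are placeholders, and <code class="highlight"><span></span><span class="n">GradScalerFP16</span></code> is defined later in this file. </p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch
from torch import nn

# Hypothetical model and batch, purely for illustration
model = nn.Linear(16, 4).cuda().half()      # parameters stored in FP16
optimizer = AdamFP16(model.parameters(), lr=1e-3)
scaler = GradScalerFP16()                   # defined below

x = torch.randn(32, 16, device='cuda', dtype=torch.float16)
y = torch.randn(32, 4, device='cuda', dtype=torch.float16)

for _ in range(10):
    optimizer.zero_grad()
    loss = (model(x) - y).pow(2).mean()
    scaler.scale(loss).backward()   # backward on the scaled loss (FP16 grads)
    scaler.step(optimizer)          # unscales to FP32, then optimizer.step()
    scaler.update()</pre></div>
</div>
</div>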
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">29</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-3</span><span class="p">,</span> <span class="n">betas</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">),</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-16</span><span class="p">,</span>
<span class="lineno">30</span> <span class="n">weight_decay</span><span class="p">:</span> <span class="n">WeightDecay</span> <span class="o">=</span> <span class="n">WeightDecay</span><span class="p">(),</span> <span class="n">optimized_update</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="lineno">31</span> <span class="n">defaults</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>Dictionary to store 32-bit (FP32) gradients for each parameter. It gets populated by the <code class="highlight"><span></span><span class="n">GradScaler</span></code>
defined below. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_fp32</span> <span class="o">=</span> <span class="p">{}</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Call the <a href="adam.html">Adam Optimizer</a> initializer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">35</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">betas</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">optimized_update</span><span class="p">,</span> <span class="n">defaults</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<h3>Initialize a parameter state</h3>
<ul><li><code class="highlight"><span></span><span class="n">state</span></code>
is the optimizer state of the parameter (tensor) </li>
<li><code class="highlight"><span></span><span class="n">group</span></code>
stores optimizer attributes of the parameter group </li>
<li><code class="highlight"><span></span><span class="n">param</span></code>
is the parameter tensor <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.902771em;vertical-align:-0.208331em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqe" style="">t</span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span></span></span></span></span></span></li></ul>
<p>All the state tensors use FP32.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span> <span class="k">def</span> <span class="nf">init_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">param</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>This is the number of optimizer steps taken on the parameter, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqe" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;step&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Exponential moving average of gradients, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqe" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;exp_avg&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">memory_format</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">preserve_format</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Exponential moving average of squared gradient values, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqe" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;exp_avg_sq&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">memory_format</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">preserve_format</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Maintain an FP32 copy of the parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;fp32_copy&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span></pre></div>
</div>
</div>
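<div class='section'>
<div class='docs'>
<p>The FP32 copy matters because FP16 cannot represent small changes to its weights. A self-contained illustration (the numbers are arbitrary): </p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch

p16 = torch.tensor(1.0, dtype=torch.float16)
p32 = torch.tensor(1.0, dtype=torch.float32)
update = 1e-4   # a typical small optimizer step

# Near 1.0 the spacing between FP16 values is about 1e-3,
# so the update is rounded away entirely.
print(p16 + update)   # tensor(1., dtype=torch.float16)
# In FP32 the update survives and can accumulate over many steps.
print(p32 + update)   # tensor(1.0001)</pre></div>
</div>
</div>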
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<h3>Take an update step for a given parameter tensor</h3>
<ul><li><code class="highlight"><span></span><span class="n">state</span></code>
is the optimizer state of the parameter (tensor) </li>
<li><code class="highlight"><span></span><span class="n">group</span></code>
stores optimizer attributes of the parameter group </li>
<li><code class="highlight"><span></span><span class="n">grad</span></code>
is the current gradient tensor <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqe" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> for the parameter <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.902771em;vertical-align:-0.208331em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqe" style="">t</span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span></span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">param</span></code>
is the parameter tensor <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.902771em;vertical-align:-0.208331em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqe" style="">t</span></span><span class="mbin mtight" style="">−</span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span></span></span></span></span></span></li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="k">def</span> <span class="nf">step_param</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">group</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">any</span><span class="p">],</span> <span class="n">grad</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Get the FP32 parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">param_fp32</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;fp32_copy&#39;</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Get the FP32 gradients if available </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">grad_fp32</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_fp32</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="lineno">71</span> <span class="k">if</span> <span class="n">grad_fp32</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">72</span> <span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_fp32</span><span class="p">[</span><span class="n">param</span><span class="p">]</span>
<span class="lineno">73</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">grad_fp32</span>
<span class="lineno">74</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Otherwise, convert the gradients to FP32 </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">grad</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Calculate weight decay </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="n">grad</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span><span class="p">(</span><span class="n">param_fp32</span><span class="p">,</span> <span class="n">grad</span><span class="p">,</span> <span class="n">group</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style=""><span class="mord mathnormal" style="">m</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqe" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqe" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="n">m</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_mv</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">grad</span><span class="p">)</span></pre></div>
</div>
</div>
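<div class='section'>
<div class='docs'>
<p>For reference, a sketch of the standard Adam moment updates that <code class="highlight"><span></span><span class="n">get_mv</span></code> computes (see the <a href="adam.html">Adam Optimizer</a> for the actual implementation; this snippet is illustrative only): </p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch

beta1, beta2 = 0.9, 0.999
grad = torch.randn(4)    # stand-in for the FP32 gradient g_t
m = torch.zeros(4)       # exp_avg,    m_{t-1}
v = torch.zeros(4)       # exp_avg_sq, v_{t-1}

m.mul_(beta1).add_(grad, alpha=1 - beta1)             # m_t = b1 * m + (1 - b1) * g
v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)   # v_t = b2 * v + (1 - b2) * g**2</pre></div>
</div>
</div>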
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Increment <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqe" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span>, the number of optimizer steps </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;step&#39;</span><span class="p">]</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Perform the <em>Adam</em> update on the FP32 copy of the parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="bp">self</span><span class="o">.</span><span class="n">adam_update</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">param_fp32</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
</div>
</div>
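<div class='section'>
<div class='docs'>
<p>Roughly, the bias-corrected update that <code class="highlight"><span></span><span class="n">adam_update</span></code> applies to the FP32 copy looks like the sketch below; the actual implementation in the <a href="adam.html">Adam Optimizer</a> may fold the corrections into the step size. </p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch

lr, eps = 1e-3, 1e-16
beta1, beta2, t = 0.9, 0.999, 1
param_fp32 = torch.zeros(4)
m, v = torch.full((4,), 0.1), torch.full((4,), 0.01)

m_hat = m / (1 - beta1 ** t)      # bias-corrected first moment
v_hat = v / (1 - beta2 ** t)      # bias-corrected second moment
param_fp32 -= lr * m_hat / (v_hat.sqrt() + eps)</pre></div>
</div>
</div>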
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Set the parameters by casting the updated FP32 copy back to the original dtype of the parameter </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">param_fp32</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<h2>Gradient Scaler with half precision gradients</h2>
<p>We extend the PyTorch gradient scaler so that it hands FP32 gradients to the <code class="highlight"><span></span><span class="n">AdamFP16</span></code> optimizer defined above.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">94</span><span class="k">class</span> <span class="nc">GradScalerFP16</span><span class="p">(</span><span class="n">grad_scaler</span><span class="o">.</span><span class="n">GradScaler</span><span class="p">):</span></pre></div>
</div>
</div>
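<div class='section'>
<div class='docs'>
<p>The idea behind loss scaling in numbers (the scale factor here is just an example): gradients that would flush to zero in FP16 survive when the loss is multiplied by a large constant before <code class="highlight"><span></span><span class="n">backward</span><span class="p">()</span></code>, and the constant is divided back out in FP32 when the gradients are unscaled. </p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch

scale = 2.0 ** 16               # example scale factor
g_true = torch.tensor(1e-8)     # a gradient smaller than FP16 can represent

print(g_true.to(torch.float16))                  # tensor(0., dtype=torch.float16) - lost
g_scaled = (g_true * scale).to(torch.float16)    # representable once scaled
g_unscaled = g_scaled.to(torch.float32) / scale  # unscale in FP32
print(g_unscaled)                                # roughly 1e-8 again</pre></div>
</div>
</div>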
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="k">def</span> <span class="nf">_unscale_grads_</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">Optimizer</span><span class="p">,</span> <span class="n">inv_scale</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">found_inf</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">102</span> <span class="n">allow_fp16</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span>
<span class="lineno">103</span> <span class="n">per_device_inv_scale</span> <span class="o">=</span> <span class="n">grad_scaler</span><span class="o">.</span><span class="n">_MultiDeviceReplicator</span><span class="p">(</span><span class="n">inv_scale</span><span class="p">)</span>
<span class="lineno">104</span> <span class="n">per_device_found_inf</span> <span class="o">=</span> <span class="n">grad_scaler</span><span class="o">.</span><span class="n">_MultiDeviceReplicator</span><span class="p">(</span><span class="n">found_inf</span><span class="p">)</span>
<span class="lineno">105</span>
<span class="lineno">106</span> <span class="n">per_device_and_dtype_grads</span> <span class="o">=</span> <span class="n">defaultdict</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">defaultdict</span><span class="p">(</span><span class="nb">list</span><span class="p">))</span> <span class="c1"># type: ignore[var-annotated]</span>
<span class="lineno">107</span>
<span class="lineno">108</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Loop through parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
<span class="lineno">111</span> <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s2">&quot;params&quot;</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Skip parameters that have no gradient </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span> <span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">114</span> <span class="k">continue</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Not implemented for sparse tensors </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">is_sparse</span><span class="p">:</span>
<span class="lineno">117</span> <span class="k">raise</span> <span class="ne">NotImplementedError</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>If we are using the <code class="highlight"><span></span><span class="n">AdamFP16</span></code>
optimizer, set <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">grad_fp32</span><span class="p">[</span><span class="n">param</span><span class="p">]</span></code>
to the FP32 gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">120</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">AdamFP16</span><span class="p">):</span>
<span class="lineno">121</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
<span class="lineno">122</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">grad_fp32</span><span class="p">[</span><span class="n">param</span><span class="p">]</span> <span class="o">=</span> <span class="n">grad</span></pre></div>
</div>
</div>
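<div class='section'>
<div class='docs'>
<p>A toy illustration of this hand-off; the dictionary below stands in for <code class="highlight"><span></span><span class="n">optimizer</span><span class="o">.</span><span class="n">grad_fp32</span></code> and is keyed by the parameter tensor itself. </p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch
from torch import nn

grad_fp32 = {}    # stands in for optimizer.grad_fp32
param = nn.Parameter(torch.zeros(3, dtype=torch.float16))
param.grad = torch.ones(3, dtype=torch.float16)

# What the loop above does for an AdamFP16 optimizer:
grad_fp32[param] = param.grad.to(torch.float)   # FP32 copy, keyed by the parameter

# Later, AdamFP16.step_param picks it up (and removes it) instead of param.grad:
grad = grad_fp32.pop(param)
print(grad.dtype)   # torch.float32</pre></div>
</div>
</div>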
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Otherwise, do not convert the gradients to FP32 </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">125</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span>
<span class="lineno">126</span>
<span class="lineno">127</span> <span class="n">per_device_and_dtype_grads</span><span class="p">[</span><span class="n">grad</span><span class="o">.</span><span class="n">device</span><span class="p">][</span><span class="n">grad</span><span class="o">.</span><span class="n">dtype</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">grad</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Unscale all the gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">for</span> <span class="n">device</span><span class="p">,</span> <span class="n">per_dtype_grads</span> <span class="ow">in</span> <span class="n">per_device_and_dtype_grads</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="lineno">131</span> <span class="k">for</span> <span class="n">grads</span> <span class="ow">in</span> <span class="n">per_dtype_grads</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
<span class="lineno">132</span> <span class="n">torch</span><span class="o">.</span><span class="n">_amp_foreach_non_finite_check_and_unscale_</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span>
<span class="lineno">133</span> <span class="n">per_device_found_inf</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="lineno">134</span> <span class="n">per_device_inv_scale</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">device</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Return the per-device tensors that indicate whether infinite or NaN gradient values were found </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="k">return</span> <span class="n">per_device_found_inf</span><span class="o">.</span><span class="n">_per_device_tensors</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>