<!-- 2024-06-21 19:35:22 +05:30 -->


<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This is an implementation of Zero-DP Memory Optimization written in PyTorch."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Zero-DP Memory Optimization"/>
<meta name="twitter:description" content="This is an implementation of Zero-DP Memory Optimization written in PyTorch."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/scaling/zero3/index.html"/>
<meta property="og:title" content="Zero-DP Memory Optimization"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Zero-DP Memory Optimization"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="This is an implementation of Zero-DP Memory Optimization written in PyTorch."/>
<title>Zero-DP Memory Optimization</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/scaling/zero3/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/scaling/zero3/__init__.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Zero-DP Memory Optimization</h1>
<p>This is an implementation of Zero-DP introduced in the paper <a href="https://arxiv.org/abs/1910.02054">ZeRO: Memory Optimization Towards Training A Trillion Parameter Models</a>,</p>
<p>It shards the optimizer state, gradients and parameters across multiple devices/nodes. It reduces the memory consumption to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.4608599999999998em;vertical-align:-0.4508599999999999em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.01em;"><span style="top:-2.655em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqg" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:-0.10903em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.485em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqe" style=""><span class="mord mtight" style="">2</span><span class="mbin mtight" style="">+</span><span class="mord mtight" style="">2</span></span><span class="mbin mtight">+</span><span class="mord mtight coloredeq eqi" 
style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span></span><span class="mclose mtight">)</span><span class="mord mtight coloredeq eqf" style=""><span class="mord mtight" style="">Ψ</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.4508599999999999em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span> bytes per device, where <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqf" style=""><span class="mord" style="">Ψ</span></span></span></span></span></span> is the number of parameters, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> is the number of shards, and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" 
style="margin-right:0.07153em">K</span></span></span></span></span></span> is the number of optimizer bytes per parameter. The <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqe" style=""><span class="mord" style="">2</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">2</span></span></span></span></span></span> terms are the parameter and gradient memory, assuming 16-bit precision, i.e. 2 bytes per parameter and 2 bytes per gradient. <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">12</span></span></span></span></span> for the Adam optimizer, because it keeps an fp32 copy of the parameters and two fp32 moments per parameter (4 + 4 + 4 bytes).</p>
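<p>A quick numeric check of the memory formula, as a minimal Python sketch. The function names are hypothetical; the sketch assumes fp16 parameters and gradients and K = 12 (mixed-precision Adam), and the 7.5B-parameter model sharded over 64 devices is only an illustrative example.</p>

```python
# Per-device memory when parameters, gradients and optimizer state are all
# sharded: (2 + 2 + K) * Psi / N_d bytes. Function names are hypothetical.

def unsharded_bytes(psi: int, k: int = 12) -> int:
    """fp16 parameters (2 bytes) + fp16 gradients (2 bytes) + K optimizer bytes, per parameter."""
    return (2 + 2 + k) * psi


def zero_dp_bytes_per_device(psi: int, n_d: int, k: int = 12) -> float:
    """Each of the N_d devices keeps only 1/N_d of every piece of model state."""
    return unsharded_bytes(psi, k) / n_d


# K = 12 for mixed-precision Adam: fp32 parameter copy + momentum + variance (4 bytes each)
psi = 7_500_000_000  # illustrative 7.5B-parameter model
print(unsharded_bytes(psi) / 1e9)               # 120.0 (GB without sharding)
print(zero_dp_bytes_per_device(psi, 64) / 1e9)  # 1.875 (GB per device with N_d = 64)
```

<p>Only the per-device share shrinks; the total across all devices stays at (2 + 2 + K)Ψ bytes.</p>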
<p>The communication volume of Zero-DP is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.02778em;">O</span><span class="mopen">(</span><span class="mord">3</span><span class="mord coloredeq eqf" style=""><span class="mord" style="">Ψ</span></span><span class="mclose">)</span></span></span></span></span>. For comparison, standard data-parallel training has a communication volume of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.02778em;">O</span><span class="mopen">(</span><span class="mord">2</span><span class="mord coloredeq eqf" style=""><span class="mord" style="">Ψ</span></span><span class="mclose">)</span></span></span></span></span>.</p>
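<p>Where the extra Ψ comes from, as a small accounting sketch. It assumes ring-based collectives, in which each reduce-scatter or all-gather moves (N<sub>d</sub> - 1)/N<sub>d</sub> of the buffer per device (close to Ψ for large N<sub>d</sub>). The function names are hypothetical.</p>

```python
from fractions import Fraction


def ring_collective_volume(psi: int, n: int) -> Fraction:
    """Elements each device sends in one ring reduce-scatter or all-gather of psi elements."""
    return Fraction(n - 1, n) * psi


def data_parallel_volume(psi: int, n: int) -> Fraction:
    # Gradient all-reduce = reduce-scatter + all-gather
    return 2 * ring_collective_volume(psi, n)


def zero_dp_volume(psi: int, n: int) -> Fraction:
    # All-gather parameters for the forward pass, all-gather them again for the
    # backward pass, then reduce-scatter gradients to the shards that own them
    return 3 * ring_collective_volume(psi, n)


psi, n = 1_000_000, 8
print(zero_dp_volume(psi, n) / data_parallel_volume(psi, n))  # 3/2
```

<p>So under these assumptions Zero-DP communicates 1.5 times as much as plain data parallelism, in exchange for the memory savings above.</p>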
<p>Although this is named <code class="highlight"><span></span><span class="n">Zero3</span></code>
, we have only implemented the Zero-DP part of it, not the Zero-R memory optimizations, which target residual memory consumption. Our implementation also supports training only a subset of the parameters while keeping the rest fixed.</p>
<p>This implementation is inspired by <a href="https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html">Fairscale FSDP</a>.</p>
<p><a href="finetune_neox.html">Here&#x27;s a script to fine-tune</a> GPT NeoX using Zero-DP memory optimization.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span></span><span class="kn">import</span> <span class="nn">functools</span>
<span class="lineno">33</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span>
<span class="lineno">34</span>
<span class="lineno">35</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">36</span><span class="kn">import</span> <span class="nn">torch.distributed</span> <span class="k">as</span> <span class="nn">dist</span>
<span class="lineno">37</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Zero3 Layer</h2>
<p>Each layer of the model (or a combination of a few consecutive layers) should be wrapped in this module.</p>
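<p>Wrapping a few consecutive layers together presumably trades fewer, larger collectives for a higher peak memory, since a whole group's parameters are gathered at once. A minimal sketch of the grouping step; <code>group_consecutive</code> is a hypothetical helper, and in practice each group might be combined (e.g. with <code>nn.Sequential</code>) before being wrapped in a <code>Zero3Layer</code>.</p>

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def group_consecutive(layers: Sequence[T], layers_per_group: int) -> List[List[T]]:
    """Split a list of layers into runs of consecutive layers (hypothetical helper).

    Each run would be wrapped as a single unit; the last run may be shorter.
    """
    if layers_per_group > 0:
        return [list(layers[i:i + layers_per_group])
                for i in range(0, len(layers), layers_per_group)]
    raise ValueError("layers_per_group must be positive")


print(group_consecutive([f"layer{i}" for i in range(5)], 2))
# [['layer0', 'layer1'], ['layer2', 'layer3'], ['layer4']]
```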
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span><span class="k">class</span> <span class="nc">Zero3Layer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>Each shard keeps its parameters in the <code class="highlight"><span></span><span class="n">chunk</span></code>
list: <code class="highlight"><span></span><span class="n">chunk</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></code>
holds the trainable parameters and <code class="highlight"><span></span><span class="n">chunk</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></code>
holds the fixed parameters. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">chunk</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>The sizes of the chunks in the <code class="highlight"><span></span><span class="n">chunk</span></code>
list. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="n">chunk_size</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>The first chunk is for trainable parameters. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="n">TRAINING_PARAMS_IDX</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>The parameters of the layer, split into two lists: trainable parameters and fixed parameters. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">param_refs</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">]]</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>CUDA stream to fetch parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">fetch_stream</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>CUDA stream to back up/accumulate gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">61</span> <span class="n">backup_stream</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>List of layers right before this layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="n">prev_layer</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s1">&#39;Zero3Layer&#39;</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>List of layers right after this layer </p>
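<p>The <code>prev_layer</code> and <code>next_layer</code> links let a layer start gathering its neighbours' parameters before they are needed. A simplified, single-threaded sketch of the resulting forward-pass order; it only records hypothetical event labels (fetch, wait, compute) and does not model CUDA streams or the real method names.</p>

```python
from typing import List


def forward_fetch_schedule(n_layers: int) -> List[str]:
    """Order of events in a forward pass where each layer prefetches the next one."""
    events = ["fetch 0"]  # nothing runs before the first layer, so fetch it up front
    for i in range(n_layers):
        events.append(f"wait {i}")  # block until layer i's gathered parameters are ready
        if i + 1 != n_layers:
            # Start gathering the next layer (on the fetch stream in the real code)...
            events.append(f"fetch {i + 1}")
        # ...so that the fetch overlaps with this layer's computation
        events.append(f"compute {i}")
    return events


print(forward_fetch_schedule(3))
```

<p>The backward pass would follow the same pattern in reverse order, using <code>prev_layer</code> instead of <code>next_layer</code>.</p>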
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">next_layer</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s1">&#39;Zero3Layer&#39;</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>The position of the current layer; used for debugging logs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span> <span class="n">layer_idx</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Whether parameters have been fetched </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">is_fetched</span><span class="p">:</span> <span class="nb">bool</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Device of the layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Data type of the layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>The module to be wrapped </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">77</span> <span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Number of nodes/devices the data is sharded across </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="n">world_size</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">module</span></code>
The module to be wrapped. </li>
<li><code class="highlight"><span></span><span class="n">rank</span></code>
The rank of the current node. </li>
<li><code class="highlight"><span></span><span class="n">world_size</span></code>
The number of nodes/devices the data is sharded across. </li>
<li><code class="highlight"><span></span><span class="n">device</span></code>
The device of the layer. </li>
<li><code class="highlight"><span></span><span class="n">dtype</span></code>
The data type of the layer.</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">world_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Initialize the properties </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
<span class="lineno">93</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="lineno">94</span> <span class="bp">self</span><span class="o">.</span><span class="n">module</span> <span class="o">=</span> <span class="n">module</span>
<span class="lineno">95</span> <span class="bp">self</span><span class="o">.</span><span class="n">prev_layer</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">96</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_layer</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">97</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_fetched</span> <span class="o">=</span> <span class="kc">False</span>
<span class="lineno">98</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">=</span> <span class="n">world_size</span>
<span class="lineno">99</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_idx</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
<span class="lineno">100</span> <span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span> <span class="o">=</span> <span class="kc">None</span>
<span class="lineno">101</span> <span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span> <span class="o">=</span> <span class="kc">None</span>
<span class="lineno">102</span>
<span class="lineno">103</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Collect all the parameters of the layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span> <span class="n">all_param_refs</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">parameters</span><span class="p">()]</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Store the shape of the parameters because we need it later to reconstruct them </p>
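<p>The shapes are needed because the merged buffer is one-dimensional. A minimal pure-Python sketch of the reverse step, assuming (since <code>_merge_and_pad_params</code> is not shown in this section) that parameters are flattened and concatenated in order with any padding at the end; <code>unflatten</code> is a hypothetical helper.</p>

```python
import math
from typing import List, Sequence, Tuple

Shape = Tuple[int, ...]


def unflatten(flat: Sequence[float], shapes: Sequence[Shape]) -> List[List[float]]:
    """Slice a merged 1-D buffer back into per-parameter pieces.

    Each piece has math.prod(shape) elements; a real implementation would then
    reshape it, e.g. with torch.Tensor.view(shape). Trailing padding is ignored.
    """
    pieces, offset = [], 0
    for shape in shapes:
        n = math.prod(shape)
        pieces.append(list(flat[offset:offset + n]))
        offset += n
    return pieces


# A 2x3 weight and a 3-element bias, followed by 3 padding elements
print(unflatten(list(range(12)), [(2, 3), (3,)]))  # [[0, 1, 2, 3, 4, 5], [6, 7, 8]]
```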
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">all_param_refs</span><span class="p">:</span>
<span class="lineno">109</span> <span class="n">p</span><span class="o">.</span><span class="n">_orig_shape</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>All parameters should have the same type </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">all_param_refs</span><span class="p">:</span>
<span class="lineno">113</span> <span class="k">assert</span> <span class="n">p</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">dtype</span><span class="p">,</span> <span class="s2">&quot;All parameters should have same dtype&quot;</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Separate parameters as trainable and fixed </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span> <span class="o">=</span> <span class="p">[[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">all_param_refs</span> <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">],</span>
<span class="lineno">117</span> <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">all_param_refs</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">p</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">]]</span>
<span class="lineno">118</span> <span class="k">del</span> <span class="n">all_param_refs</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>The <code class="highlight"><span></span><span class="n">rank</span> <span class="o">=</span> <span class="mi">0</span></code>
node will calculate the size each device/node should store, and distribute the parameters accordingly. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="k">if</span> <span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Merge and pad trainable (<code class="highlight"><span></span><span class="n">merged_params</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></code>
) and fixed (<code class="highlight"><span></span><span class="n">merged_params</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></code>
) parameters </p>
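<p>The later <code>view(world_size, -1)</code> only works if each merged buffer has a length divisible by <code>world_size</code>. Assuming <code>_merge_and_pad_params</code> (not shown in this section) pads the concatenated parameters up to the next multiple of <code>world_size</code>, the sizes work out as in this sketch; the helper names are hypothetical.</p>

```python
def padded_length(numel: int, world_size: int) -> int:
    """Round numel up to the next multiple of world_size (ceiling division)."""
    return -(-numel // world_size) * world_size


def chunk_size(numel: int, world_size: int) -> int:
    """Elements each rank stores, i.e. len(padded) // world_size."""
    return padded_length(numel, world_size) // world_size


trainable, fixed, world_size = 10, 7, 4
print([padded_length(n, world_size) for n in (trainable, fixed)])  # [12, 8]
print([chunk_size(n, world_size) for n in (trainable, fixed)])     # [3, 2]
```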
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">merged_params</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_merge_and_pad_params</span><span class="p">(</span><span class="n">ps</span><span class="p">)</span> <span class="k">for</span> <span class="n">ps</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Calculate the chunk sizes of trainable and fixed params </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span> <span class="o">=</span> <span class="p">[(</span><span class="nb">len</span><span class="p">(</span><span class="n">p</span><span class="p">)</span> <span class="o">//</span> <span class="n">world_size</span> <span class="k">if</span> <span class="n">p</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">merged_params</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Broadcast the sizes </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">dist</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">),</span> <span class="n">src</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="lineno">129</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Create an empty tensor to receive the sizes </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">131</span> <span class="n">chunk_size</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Receive the sizes </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="n">dist</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="n">chunk_size</span><span class="p">,</span> <span class="n">src</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="lineno">134</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span> <span class="o">=</span> <span class="n">chunk_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Create parameters for trainable (<code class="highlight"><span></span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></code>
) and fixed (<code class="highlight"><span></span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></code>
) parameters to be stored on the current device/node </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk</span> <span class="o">=</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="n">s</span><span class="p">,)),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="n">i</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">)</span>
<span class="lineno">139</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>An empty tensor to receive this node&#x27;s trainable and fixed shards, concatenated </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="n">chunk</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">),))</span>
<span class="lineno">143</span>
<span class="lineno">144</span> <span class="k">if</span> <span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
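<p>Concatenate both trainable and fixed params. Each merged tensor is viewed as <code>[world_size, -1]</code> and the views are joined along the last dimension, so after flattening, the <em>i</em>-th block holds node <em>i</em>&#x27;s trainable shard followed by its fixed shard </p>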
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">all_params</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">p</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">world_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">merged_params</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">147</span> <span class="k">del</span> <span class="n">merged_params</span></pre></div>
</div>
</div>
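<p>The layout produced by this concatenation can be reproduced with plain Python lists. This is a sketch, not the implementation: the helper <code>interleave_for_scatter</code> is hypothetical, and lists stand in for the merged, padded trainable and fixed tensors (each with a length divisible by <code>world_size</code>).</p>

```python
from typing import List, Sequence


def interleave_for_scatter(merged: Sequence[List[int]], world_size: int) -> List[int]:
    """Pure-Python analogue of
    torch.cat([p.view(world_size, -1) for p in merged], dim=-1).view(-1)

    Every list in `merged` must have a length divisible by world_size.
    Block r of the result holds rank r's slice of each list, in order.
    """
    out: List[int] = []
    for r in range(world_size):
        for p in merged:
            n = len(p) // world_size  # elements of p that each rank stores
            out.extend(p[r * n:(r + 1) * n])
    return out


trainable = [0, 1, 2, 3, 4, 5]  # 6 elements -> 3 per rank
fixed = [10, 11, 12, 13]        # 4 elements -> 2 per rank
flat = interleave_for_scatter([trainable, fixed], world_size=2)
print(flat)  # [0, 1, 2, 10, 11, 3, 4, 5, 12, 13]
```

<p>With <code>world_size=2</code>, rank 0 gets the first five elements (three trainable, two fixed) and rank 1 gets the last five.</p>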
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Scatter them to all the nodes/devices </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">dist</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">chunk</span><span class="p">,</span> <span class="nb">list</span><span class="p">(</span><span class="n">all_params</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">))))</span>
<span class="lineno">151</span> <span class="k">del</span> <span class="n">all_params</span>
<span class="lineno">152</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Receive the parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">dist</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">chunk</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Split the received data into its trainable and fixed parts and copy them into the chunk parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span> <span class="n">chunk</span> <span class="o">=</span> <span class="n">chunk</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">)</span>
<span class="lineno">158</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">chunk</span><span class="p">):</span>
<span class="lineno">159</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">data</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">c</span>
<span class="lineno">160</span> <span class="k">del</span> <span class="n">chunk</span></pre></div>
</div>
</div>
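<p>A pure-Python simulation of the scatter and the split can show which elements end up on which node. The function <code>receive_shard</code> is hypothetical; it slices a list the way <code>dist.scatter</code> hands out equal-sized pieces from rank 0, and it assumes the trainable group comes first (<code>TRAINING_PARAMS_IDX == 0</code>, as in <code>self.chunk</code>).</p>

```python
from typing import List, Tuple


def receive_shard(all_params: List[int], chunk_size: List[int],
                  rank: int) -> Tuple[List[int], List[int]]:
    """Simulate what `rank` holds after the scatter and the split above.

    Assumes two groups, trainable first (index 0) and fixed second
    (index 1), matching self.chunk in this module.
    """
    per_rank = sum(chunk_size)
    # dist.scatter: rank r receives the r-th piece of size sum(chunk_size)
    chunk = all_params[rank * per_rank:(rank + 1) * per_rank]
    # chunk.split(chunk_size): trainable part, then fixed part
    trainable = chunk[:chunk_size[0]]
    fixed = chunk[chunk_size[0]:per_rank]
    return trainable, fixed


# Flattened rank-0 tensor for two ranks with chunk_size == [3, 2]
all_params = [0, 1, 2, 10, 11, 3, 4, 5, 12, 13]
print(receive_shard(all_params, [3, 2], rank=1))  # ([3, 4, 5], [12, 13])
```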
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Clean up the original (unsharded) module parameters to free their memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cleanup_params</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Add a backward hook. It is called once the gradients with respect to the module&#x27;s inputs have been computed. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_ref</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_full_backward_hook</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook</span><span class="p">)</span> <span class="c1"># type: ignore</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<h4>Merge all the parameters and pad the result so that its length is divisible by <code class="highlight"><span></span><span class="n">world_size</span></code>
.</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="k">def</span> <span class="nf">_merge_and_pad_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Total number of parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">size</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">params</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>If it is not divisible by <code class="highlight"><span></span><span class="n">world_size</span></code>
, pad it </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="k">if</span> <span class="n">size</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">177</span> <span class="n">padding_fixed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">-</span> <span class="p">(</span><span class="n">size</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Otherwise, no need to pad </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">180</span> <span class="n">padding_fixed</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Create an empty padding tensor </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">182</span> <span class="n">padding</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="n">padding_fixed</span><span class="p">,))</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Flatten and concatenate all the parameters, then append the padding </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">p</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">params</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="n">padding</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
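<p>The padding amount can be checked without any tensors. The sketch below, with a hypothetical helper <code>padded_length</code>, reproduces only the size arithmetic of <code>_merge_and_pad_params</code> on shape tuples; <code>(-size) % world_size</code> is a compact equivalent of the <code>if</code>/<code>else</code> above.</p>

```python
from math import prod
from typing import List, Tuple


def padded_length(shapes: List[Tuple[int, ...]], world_size: int) -> Tuple[int, int]:
    """Return (total_elements, padding) for a list of parameter shapes.

    Mirrors the size arithmetic in _merge_and_pad_params, without tensors.
    """
    size = sum(prod(s) for s in shapes)  # prod(()) == 1, like a scalar's numel()
    # Smallest non-negative padding that makes size divisible by world_size;
    # equals world_size - size % world_size, or 0 when already divisible
    padding = (-size) % world_size
    return size, padding


# Hypothetical layer: a 3x5 weight and a 3-element bias on 4 devices
size, padding = padded_length([(3, 5), (3,)], world_size=4)
chunk = (size + padding) // 4
print(size, padding, chunk)  # 18 2 5
```

<p>Each of the 4 devices would then store 5 elements of this group, and the last 2 elements of the last device&#x27;s slice are padding.</p>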
<div class='section' id='section-43'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<h3>Get trainable chunk/shard of the parameters.</h3>
<p>This is what we pass on to the optimizer on the current node.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">186</span> <span class="k">def</span> <span class="nf">get_trainable_chunk</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Return an empty list if there are no trainable parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">])</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">194</span> <span class="k">return</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Return the trainable chunk as a list </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]]</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<h4>Create an empty tensor of the given shape.</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="k">def</span> <span class="nf">_empty</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<h4>Clean up the parameter data</h4>
<p>This will release all the memory used by the layer parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">206</span> <span class="k">def</span> <span class="nf">_cleanup_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Set the flag to indicate that the parameters are not fetched </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">214</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_fetched</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Iterate through all parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="k">for</span> <span class="n">ps</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">:</span>
<span class="lineno">218</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">ps</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>Record the current stream on the tensor, so that the CUDA caching allocator does not reuse its memory until the work already queued on that stream has finished </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">220</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>Check that the parameter starts at offset 0 of its storage, a basic sanity check that it does not share the storage with other tensors </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">222</span> <span class="k">assert</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">storage_offset</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;The tensor is not the sole occupant of the storage.&quot;</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Resize the storage to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span></span></span></span></span>. This will release the memory used by the parameter.</p>
<p><strong>Setting <code class="highlight"><span></span><span class="n">p</span><span class="o">.</span><span class="n">data</span></code>
will not release the memory, since the autograd graph keeps a reference to it.</strong> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">226</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">storage</span><span class="p">()</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># This is what actually clears the memory</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>Make sure the parameter has no gradient data </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">228</span> <span class="k">assert</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;Gradients should be None&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<h3>Fetch the parameters from all shards</h3>
<p>This will fetch all the parameter data from all the nodes and rebuild the parameters on each node.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">230</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">231</span> <span class="k">def</span> <span class="nf">fetch_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p>Skip if already fetched </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">239</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_fetched</span><span class="p">:</span>
<span class="lineno">240</span> <span class="k">return</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Set the flag </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">243</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_fetched</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Skip if there&#x27;s nothing to fetch or share. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">246</span> <span class="k">if</span> <span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">247</span> <span class="k">return</span></pre></div>
</div>
</div>
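<p>The guards at the top of <code>fetch_params</code> make it idempotent until <code>_cleanup_params</code> resets the flag. The toy class below is a hypothetical, communication-free model of that lifecycle: a list comprehension stands in for <code>dist.all_gather</code>, and dropping the list stands in for resizing the storage to 0.</p>

```python
from typing import List, Optional


class ToyShardedLayer:
    """Hypothetical model of the fetch/cleanup lifecycle (no torch, no dist)."""

    def __init__(self, shards: List[List[int]]):
        self.shards = shards  # stands in for every rank's chunk
        self.full: Optional[List[int]] = None
        self.is_fetched = False
        self.gathers = 0  # how many times we actually "communicated"

    def fetch_params(self) -> None:
        if self.is_fetched:  # skip if already fetched
            return
        self.is_fetched = True
        if sum(len(s) for s in self.shards) == 0:  # nothing to fetch or share
            return
        self.gathers += 1
        # stands in for dist.all_gather followed by the copy into parameters
        self.full = [x for s in self.shards for x in s]

    def cleanup_params(self) -> None:
        self.is_fetched = False
        self.full = None  # stands in for p.data.storage().resize_(0)


layer = ToyShardedLayer([[1, 2], [3, 4]])
layer.fetch_params()  # e.g. prefetched by the previous layer
layer.fetch_params()  # the call in forward() is then a no-op
print(layer.full, layer.gathers)  # [1, 2, 3, 4] 1
```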
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Use <code class="highlight"><span></span><span class="n">fetch_stream</span></code>
to fetch the parameters from all the shards </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">250</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>Create an empty tensor to receive the parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">252</span> <span class="n">buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">*</span> <span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">),))</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p>Split the contiguous buffer into one equal-sized piece per node. These splits are views of <code>buffer</code>. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">254</span> <span class="n">buffers</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">buffer</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">)))</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<p>Concatenate both trainable and fixed chunks </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">257</span> <span class="n">chunk</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>Gather the parameters from all the nodes/devices </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">260</span> <span class="n">dist</span><span class="o">.</span><span class="n">all_gather</span><span class="p">(</span><span class="n">buffers</span><span class="p">,</span> <span class="n">chunk</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<p>Split the gathered parameters into the trainable and fixed chunks </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">263</span> <span class="n">params</span> <span class="o">=</span> <span class="n">buffer</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">))</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<p>Record <code>fetch_stream</code> on the buffers so that their memory is not reused before the gather has finished, then drop the references to them </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">265</span> <span class="n">buffer</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">266</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="n">buffers</span><span class="p">:</span>
<span class="lineno">267</span> <span class="n">b</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">268</span> <span class="n">buffer</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">269</span> <span class="k">del</span> <span class="n">buffer</span>
<span class="lineno">270</span> <span class="k">del</span> <span class="n">buffers</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<p>Reshape the trainable and fixed parameters into contiguous 1-D tensors </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">273</span> <span class="n">params</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">params</span><span class="p">]</span></pre></div>
</div>
</div>
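<p>Splitting the gathered buffer is the inverse of the layout built on rank 0 during initialization. The hypothetical helper <code>regroup_gathered</code> below mimics <code>view(-1, sum(chunk_size)).split(chunk_size, dim=1)</code> followed by <code>reshape(-1)</code> on plain lists, assuming <code>sum(chunk_size)</code> is non-zero (the early return above handles the zero case). In PyTorch the column slices from <code>split(..., dim=1)</code> are generally not contiguous, which is why <code>reshape</code>, which copies when needed, is used rather than <code>view</code>.</p>

```python
from typing import List


def regroup_gathered(buffer: List[int], chunk_size: List[int]) -> List[List[int]]:
    """Pure-Python analogue of
    buffer.view(-1, sum(chunk_size)).split(chunk_size, dim=1)
    followed by p.reshape(-1) on each piece.

    `buffer` is every rank's chunk concatenated in rank order (the result of
    all_gather). Returns one flat list per group, e.g. [trainable, fixed].
    """
    row = sum(chunk_size)
    rows = [buffer[i:i + row] for i in range(0, len(buffer), row)]
    groups: List[List[int]] = []
    start = 0
    for n in chunk_size:
        # Take columns [start, start + n) of every row, row by row
        groups.append([x for r in rows for x in r[start:start + n]])
        start += n
    return groups


gathered = [0, 1, 2, 10, 11, 3, 4, 5, 12, 13]  # rank 0's chunk, then rank 1's
trainable, fixed = regroup_gathered(gathered, [3, 2])
print(trainable, fixed)  # [0, 1, 2, 3, 4, 5] [10, 11, 12, 13]
```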
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
<p>Collect the individual parameter tensors </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">276</span> <span class="k">for</span> <span class="n">cont</span><span class="p">,</span> <span class="n">ps</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
<p>If there are no parameters, skip </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">278</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">ps</span><span class="p">:</span>
<span class="lineno">279</span> <span class="k">continue</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
<p>Offset into the contiguous tensor </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">282</span> <span class="n">offset</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-70'>
<div class='docs'>
<div class='section-link'>
<a href='#section-70'>#</a>
</div>
<p>Iterate through the model parameters and assign them values from the contiguous tensor </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">284</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">ps</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-71'>
<div class='docs'>
<div class='section-link'>
<a href='#section-71'>#</a>
</div>
<p>Original parameter shape </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">286</span> <span class="n">shape</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">_orig_shape</span> <span class="c1"># type: ignore[attr-defined]</span></pre></div>
</div>
</div>
<div class='section' id='section-72'>
<div class='docs'>
<div class='section-link'>
<a href='#section-72'>#</a>
</div>
<p>Change the storage size of the parameter. This was set to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span></span></span></span></span> when we cleaned up the parameters. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">288</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">storage</span><span class="p">()</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span><span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-73'>
<div class='docs'>
<div class='section-link'>
<a href='#section-73'>#</a>
</div>
<p>Assign the values from the contiguous tensor </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">290</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">cont</span><span class="p">[</span><span class="n">offset</span><span class="p">:</span> <span class="n">offset</span> <span class="o">+</span> <span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-74'>
<div class='docs'>
<div class='section-link'>
<a href='#section-74'>#</a>
</div>
<p>Record <code>fetch_stream</code> on the parameter data so that its memory is not reused before the copy on that stream has finished </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">292</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-75'>
<div class='docs'>
<div class='section-link'>
<a href='#section-75'>#</a>
</div>
<p>Update the offset </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">294</span> <span class="n">offset</span> <span class="o">+=</span> <span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span></pre></div>
</div>
</div>
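<p>The loop above walks one contiguous tensor and carves out one slice per parameter, advancing the offset by each parameter&#x27;s number of elements. The following is a minimal pure-Python sketch of that bookkeeping only, using plain lists instead of tensors; the <code class="highlight"><span></span><span class="n">unflatten</span></code> helper is hypothetical and not part of this implementation, and reshaping is omitted because lists have no shape.</p>

```python
from math import prod
from typing import List, Sequence, Tuple


def unflatten(flat: Sequence[float], shapes: List[Tuple[int, ...]]) -> List[List[float]]:
    """Hypothetical helper: split a flat sequence into one slice per shape.

    Each slice holds prod(shape) elements, mirroring `shape.numel()`.
    """
    out = []
    offset = 0
    for shape in shapes:
        n = prod(shape)  # number of elements, like torch.Size.numel()
        out.append(list(flat[offset: offset + n]))
        offset += n  # move to the start of the next parameter
    return out


# A 2x3 weight and a bias of length 3, followed by two extra trailing values
flat = list(range(11))
parts = unflatten(flat, [(2, 3), (3,)])
```

<p>In this sketch, trailing elements past the last parameter are never read, so a flat tensor may be longer than the parameters it holds.</p>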
<div class='section' id='section-76'>
<div class='docs'>
<div class='section-link'>
<a href='#section-76'>#</a>
</div>
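<p>Record that the contiguous tensor is used on <code class="highlight"><span></span><span class="n">fetch_stream</span></code>, so its memory is not reused before the copies queued on that stream have finished </p>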
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">297</span> <span class="n">cont</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-77'>
<div class='docs'>
<div class='section-link'>
<a href='#section-77'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">300</span> <span class="k">del</span> <span class="n">params</span></pre></div>
</div>
</div>
<div class='section' id='section-78'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-78'>#</a>
</div>
<h3>Forward pass</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">302</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-79'>
<div class='docs'>
<div class='section-link'>
<a href='#section-79'>#</a>
</div>
<p>Fetch all the parameters of the current layer. The previous layer normally starts this fetch, so this call just makes sure the parameters are fetched (the first layer, for instance, has no previous layer to start it). </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">309</span> <span class="bp">self</span><span class="o">.</span><span class="n">fetch_params</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-80'>
<div class='docs'>
<div class='section-link'>
<a href='#section-80'>#</a>
</div>
<p>Wait for parameter fetching to complete. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">312</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">wait_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-81'>
<div class='docs'>
<div class='section-link'>
<a href='#section-81'>#</a>
</div>
<p>Start fetching the parameters of the following layers, so that they are fetched while the current layer does its computations. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">316</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_layer</span><span class="p">:</span>
<span class="lineno">317</span> <span class="n">layer</span><span class="o">.</span><span class="n">fetch_params</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-82'>
<div class='docs'>
<div class='section-link'>
<a href='#section-82'>#</a>
</div>
<p>Add backward hooks to the parameters of the current layer if autograd is enabled. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">320</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">is_grad_enabled</span><span class="p">():</span>
<span class="lineno">321</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_backward_hooks</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-83'>
<div class='docs'>
<div class='section-link'>
<a href='#section-83'>#</a>
</div>
<p>Compute the outputs of the current layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">324</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">module</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-84'>
<div class='docs'>
<div class='section-link'>
<a href='#section-84'>#</a>
</div>
<p>Clean up the parameters of the layer.</p>
<p><em>Skip cleaning up if autograd is enabled and this is the last layer in the network, because we will need to fetch the parameters again for the backward pass.</em> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">330</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">is_grad_enabled</span><span class="p">()</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_layer</span><span class="p">:</span>
<span class="lineno">331</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cleanup_params</span><span class="p">()</span>
<span class="lineno">332</span>
<span class="lineno">333</span> <span class="k">return</span> <span class="n">res</span></pre></div>
</div>
</div>
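<p>The forward pass above follows a fixed pattern: make sure this layer&#x27;s parameters are present, prefetch the next layer, compute, then free the parameters unless this is the last layer and a backward pass follows. The sketch below replays that schedule with a hypothetical <code class="highlight"><span></span><span class="n">SimLayer</span></code> class that only logs events. It assumes <code class="highlight"><span></span><span class="n">fetch_params</span></code> is idempotent, and it runs sequentially, whereas the real fetch is asynchronous on <code class="highlight"><span></span><span class="n">fetch_stream</span></code>.</p>

```python
from typing import List


class SimLayer:
    """Hypothetical stand-in for a Zero3Layer that only logs events."""

    def __init__(self, name: str, log: List[str]):
        self.name = name
        self.log = log
        self.next_layer: List["SimLayer"] = []
        self.fetched = False

    def fetch_params(self) -> None:
        # Assumed idempotent: a layer that was already prefetched is not fetched twice
        if not self.fetched:
            self.fetched = True
            self.log.append(f"fetch {self.name}")

    def cleanup_params(self) -> None:
        self.fetched = False
        self.log.append(f"free {self.name}")

    def forward(self, grad_enabled: bool) -> None:
        self.fetch_params()  # no-op if the previous layer already prefetched us
        for layer in self.next_layer:
            layer.fetch_params()  # prefetch while this layer computes
        self.log.append(f"compute {self.name}")
        # Keep the last layer's parameters when backward will need them right away
        if not grad_enabled or self.next_layer:
            self.cleanup_params()


log: List[str] = []
layers = [SimLayer(n, log) for n in ("a", "b", "c")]
for cur, nxt in zip(layers, layers[1:]):
    cur.next_layer.append(nxt)
for layer in layers:
    layer.forward(grad_enabled=True)
```

<p>With autograd enabled, only the last layer keeps its parameters after the forward pass, so at most two layers&#x27; parameters are materialized at any point in this schedule.</p>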
<div class='section' id='section-85'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-85'>#</a>
</div>
<h4>Add backward hooks to the parameters of the current layer.</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">335</span> <span class="k">def</span> <span class="nf">_add_backward_hooks</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-86'>
<div class='docs'>
<div class='section-link'>
<a href='#section-86'>#</a>
</div>
<p>Number of backward hooks added </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">341</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_handles</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-87'>
<div class='docs'>
<div class='section-link'>
<a href='#section-87'>#</a>
</div>
<p>Loop through trainable parameters of the current layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">344</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-88'>
<div class='docs'>
<div class='section-link'>
<a href='#section-88'>#</a>
</div>
<p>Make sure a hook hasn&#x27;t already been added </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">346</span> <span class="k">assert</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="s2">&quot;_hook_handle&quot;</span><span class="p">),</span> <span class="s1">&#39;Parameter has already been hooked&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-89'>
<div class='docs'>
<div class='section-link'>
<a href='#section-89'>#</a>
</div>
<p>Use <code class="highlight"><span></span><span class="n">expand_as</span></code>
to create an autograd step which we can intercept </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">348</span> <span class="n">p_tmp</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">expand_as</span><span class="p">(</span><span class="n">p</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-90'>
<div class='docs'>
<div class='section-link'>
<a href='#section-90'>#</a>
</div>
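<p>Get the gradient accumulator (<code class="highlight"><span></span><span class="n">AccumulateGrad</span></code>) node of <code class="highlight"><span></span><span class="n">p</span></code>, which runs when the gradient of <code class="highlight"><span></span><span class="n">p</span></code> is accumulated; the backward hook is attached to it. <a href="https://amsword.medium.com/understanding-pytorchs-autograd-with-grad-fn-and-next-functions-b2c4836daa00">This blog post discusses <code class="highlight"><span></span><span class="n">grad_acc</span></code>
</a>. </p>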
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">351</span> <span class="n">grad_acc</span> <span class="o">=</span> <span class="n">p_tmp</span><span class="o">.</span><span class="n">grad_fn</span><span class="o">.</span><span class="n">next_functions</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-91'>
<div class='docs'>
<div class='section-link'>
<a href='#section-91'>#</a>
</div>
<p>Add the backward hook </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">353</span> <span class="n">handle</span> <span class="o">=</span> <span class="n">grad_acc</span><span class="o">.</span><span class="n">register_hook</span><span class="p">(</span>
<span class="lineno">354</span> <span class="n">functools</span><span class="o">.</span><span class="n">partial</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_post_backward_hook</span><span class="p">,</span> <span class="n">p</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-92'>
<div class='docs'>
<div class='section-link'>
<a href='#section-92'>#</a>
</div>
<p>Keep a reference to the handle </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">356</span> <span class="n">p</span><span class="o">.</span><span class="n">_hook_handle</span> <span class="o">=</span> <span class="n">handle</span></pre></div>
</div>
</div>
<div class='section' id='section-93'>
<div class='docs'>
<div class='section-link'>
<a href='#section-93'>#</a>
</div>
<p>Increment the number of hooks added </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">358</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_handles</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-94'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-94'>#</a>
</div>
<h4>Handle a backward event</h4>
<p>This gets called by parameter backward hooks and the module backward hook.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">360</span> <span class="k">def</span> <span class="nf">_backward_event</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-95'>
<div class='docs'>
<div class='section-link'>
<a href='#section-95'>#</a>
</div>
<p>Decrement the hooks counter </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">368</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_handles</span> <span class="o">-=</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-96'>
<div class='docs'>
<div class='section-link'>
<a href='#section-96'>#</a>
</div>
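<p>If all the hooks (including the module hook) have been called, back up the gradients and clean up the parameters. The counter starts at the number of parameter hooks and the module hook adds one more decrement, so it reaches -1 only after every hook has fired. </p>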
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">372</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_handles</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
<span class="lineno">373</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backup_grads</span><span class="p">()</span>
<span class="lineno">374</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cleanup_params</span><span class="p">()</span></pre></div>
</div>
</div>
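<p>The <code class="highlight"><span></span><span class="o">==</span> <span class="o">-</span><span class="mi">1</span></code> check works regardless of the order in which autograd calls the hooks. A minimal model of the counter, using a hypothetical <code class="highlight"><span></span><span class="n">BackwardCounter</span></code> class in which a counter stands in for <code class="highlight"><span></span><span class="n">_backup_grads</span></code> and <code class="highlight"><span></span><span class="n">_cleanup_params</span></code>:</p>

```python
from typing import List


class BackwardCounter:
    """Hypothetical model of the hook counter used by `_backward_event`.

    The counter starts at the number of parameter hooks. Every parameter hook
    and the single module hook decrement it, so it reaches -1 exactly once,
    after all of them have fired.
    """

    def __init__(self, n_param_hooks: int):
        self.remaining = n_param_hooks
        self.finalized = 0  # how many times backup + cleanup would have run

    def backward_event(self) -> None:
        self.remaining -= 1
        if self.remaining == -1:
            self.finalized += 1  # stands in for _backup_grads() and _cleanup_params()


c = BackwardCounter(n_param_hooks=2)
events: List[int] = []
for _ in range(3):  # two parameter hooks + one module hook
    c.backward_event()
    events.append(c.finalized)
```

<p>A layer with no trainable parameters starts at 0, so its module hook alone triggers the cleanup.</p>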
<div class='section' id='section-97'>
<div class='docs'>
<div class='section-link'>
<a href='#section-97'>#</a>
</div>
<p>Start fetching the parameters of the previous layer, because autograd will process its gradients next. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">377</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">prev_layer</span><span class="p">:</span>
<span class="lineno">378</span> <span class="n">layer</span><span class="o">.</span><span class="n">fetch_params</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-98'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-98'>#</a>
</div>
<h4>Parameter backward hook</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">380</span> <span class="k">def</span> <span class="nf">_post_backward_hook</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-99'>
<div class='docs'>
<div class='section-link'>
<a href='#section-99'>#</a>
</div>
<p>Remove the handle from the parameter </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">385</span> <span class="n">p</span><span class="o">.</span><span class="n">_hook_handle</span><span class="o">.</span><span class="n">remove</span><span class="p">()</span> <span class="c1"># type: ignore[attr-defined]</span>
<span class="lineno">386</span> <span class="nb">delattr</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="s2">&quot;_hook_handle&quot;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-100'>
<div class='docs'>
<div class='section-link'>
<a href='#section-100'>#</a>
</div>
<p>Handle a backward event </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">389</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_event</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-101'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-101'>#</a>
</div>
<h4>Module backward hook</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">391</span> <span class="k">def</span> <span class="nf">_backward_hook</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-102'>
<div class='docs'>
<div class='section-link'>
<a href='#section-102'>#</a>
</div>
<p>Handle a backward event </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">396</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_event</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-103'>
<div class='docs'>
<div class='section-link'>
<a href='#section-103'>#</a>
</div>
<p>The previous layer will compute its gradients next, so make sure it has finished fetching its parameters. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">399</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">wait_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-104'>
<div class='docs'>
<div class='section-link'>
<a href='#section-104'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">402</span> <span class="k">return</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-105'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-105'>#</a>
</div>
<h3>Back up the gradients of the current layer</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">404</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">405</span> <span class="k">def</span> <span class="nf">_backup_grads</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-106'>
<div class='docs'>
<div class='section-link'>
<a href='#section-106'>#</a>
</div>
<p>Skip if there are no trainable parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">410</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">411</span> <span class="k">return</span></pre></div>
</div>
</div>
<div class='section' id='section-107'>
<div class='docs'>
<div class='section-link'>
<a href='#section-107'>#</a>
</div>
<p>Use the backup stream to back up the gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">414</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-108'>
<div class='docs'>
<div class='section-link'>
<a href='#section-108'>#</a>
</div>
<p>Buffer to store the gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">416</span> <span class="n">buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">],))</span></pre></div>
</div>
</div>
<div class='section' id='section-109'>
<div class='docs'>
<div class='section-link'>
<a href='#section-109'>#</a>
</div>
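<p>Split the contiguous buffer into one chunk per node. These chunks are views of <code class="highlight"><span></span><span class="n">buffer</span></code>, so the gradients copied into <code class="highlight"><span></span><span class="n">buffer</span></code> below also appear in them. </p>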
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">418</span> <span class="n">buffers</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">buffer</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]))</span></pre></div>
</div>
</div>
<div class='section' id='section-110'>
<div class='docs'>
<div class='section-link'>
<a href='#section-110'>#</a>
</div>
<p>Offset into the contiguous buffer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">421</span> <span class="n">offset</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-111'>
<div class='docs'>
<div class='section-link'>
<a href='#section-111'>#</a>
</div>
<p>Iterate through trainable parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">423</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-112'>
<div class='docs'>
<div class='section-link'>
<a href='#section-112'>#</a>
</div>
<p>Copy the flattened gradient of the parameter into the buffer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">425</span> <span class="n">shape</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">_orig_shape</span> <span class="c1"># type: ignore[attr-defined]</span>
<span class="lineno">426</span> <span class="n">buffer</span><span class="p">[</span><span class="n">offset</span><span class="p">:</span> <span class="n">offset</span> <span class="o">+</span> <span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()]</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-113'>
<div class='docs'>
<div class='section-link'>
<a href='#section-113'>#</a>
</div>
<p>Update the offset </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">428</span> <span class="n">offset</span> <span class="o">+=</span> <span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-114'>
<div class='docs'>
<div class='section-link'>
<a href='#section-114'>#</a>
</div>
<p>Clear the parameter gradient, since it is now stored in the buffer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">430</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
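<p>Taken together, the steps above flatten every gradient into one buffer of <code class="highlight"><span></span><span class="n">world_size</span> <span class="o">*</span> <span class="n">chunk_size</span></code> elements and cut it into one chunk per node. A pure-Python sketch under that assumption, with a hypothetical <code class="highlight"><span></span><span class="n">flatten_and_split</span></code> helper: here the unused tail is zero-filled, whereas the code above allocates with <code class="highlight"><span></span><span class="n">_empty</span></code>, which is likely uninitialized, and list slices are copies rather than views.</p>

```python
from math import prod
from typing import List, Tuple


def flatten_and_split(grads: List[List[float]], shapes: List[Tuple[int, ...]],
                      world_size: int, chunk_size: int) -> List[List[float]]:
    """Hypothetical helper: pack gradients into one buffer, then split per node."""
    buffer = [0.0] * (world_size * chunk_size)  # zero-filled only in this sketch
    offset = 0
    for g, shape in zip(grads, shapes):
        n = prod(shape)
        assert len(g) == n, "gradient must match its parameter's element count"
        buffer[offset: offset + n] = g  # like p.grad.view(-1)
        offset += n
    # Chunk i is the part of the buffer owned by node i
    return [buffer[i * chunk_size: (i + 1) * chunk_size] for i in range(world_size)]


# A 2x2 weight gradient and a 1-element bias gradient, split across 2 nodes
chunks = flatten_and_split([[1.0, 2.0, 3.0, 4.0], [5.0]], [(2, 2), (1,)],
                           world_size=2, chunk_size=3)
```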
<div class='section' id='section-115'>
<div class='docs'>
<div class='section-link'>
<a href='#section-115'>#</a>
</div>
<p>Empty tensor to accumulate the gradients of the current shard </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">433</span> <span class="n">grad</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">],))</span></pre></div>
</div>
</div>
<div class='section' id='section-116'>
<div class='docs'>
<div class='section-link'>
<a href='#section-116'>#</a>
</div>
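<p>Accumulate the gradients of each shard. <code class="highlight"><span></span><span class="n">reduce_scatter</span></code> sends chunk <em>i</em> of <code class="highlight"><span></span><span class="n">buffers</span></code> from every node to node <em>i</em>, which sums (reduces) the chunks it receives into <code class="highlight"><span></span><span class="n">grad</span></code>. </p>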
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">436</span> <span class="n">dist</span><span class="o">.</span><span class="n">reduce_scatter</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="n">buffers</span><span class="p">)</span></pre></div>
</div>
</div>
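<p>The semantics of <code class="highlight"><span></span><span class="n">dist</span><span class="o">.</span><span class="n">reduce_scatter</span></code> with its default <code class="highlight"><span></span><span class="n">ReduceOp</span><span class="o">.</span><span class="n">SUM</span></code> can be simulated without any communication. The sketch below is a hypothetical single-process model, not the distributed implementation: each rank contributes one chunk per destination rank, and rank <em>i</em> ends up with the element-wise sum of every rank&#x27;s chunk <em>i</em>.</p>

```python
from typing import List


def reduce_scatter_sim(per_rank_buffers: List[List[List[float]]]) -> List[List[float]]:
    """Simulate a SUM reduce-scatter over world_size ranks in one process.

    per_rank_buffers[r][i] is the chunk that rank r contributes for rank i.
    The returned list holds, at index i, what rank i would receive.
    """
    world_size = len(per_rank_buffers)
    result = []
    for i in range(world_size):
        chunks = [per_rank_buffers[r][i] for r in range(world_size)]
        result.append([sum(vals) for vals in zip(*chunks)])  # element-wise sum
    return result


# Two ranks, each holding a full gradient split into two chunks of size 2
rank0 = [[1.0, 2.0], [3.0, 4.0]]
rank1 = [[10.0, 20.0], [30.0, 40.0]]
shards = reduce_scatter_sim([rank0, rank1])
```

<p>Each rank receives only the summed gradient of the shard it owns, which is all its optimizer needs to update that shard.</p>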
<div class='section' id='section-117'>
<div class='docs'>
<div class='section-link'>
<a href='#section-117'>#</a>
</div>
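<p>Record that the buffers are used on <code class="highlight"><span></span><span class="n">fetch_stream</span></code>, so the caching allocator does not reuse their memory before the work queued there completes, and then drop the references to them </p>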
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">439</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="n">buffers</span><span class="p">:</span>
<span class="lineno">440</span> <span class="n">b</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">441</span> <span class="n">buffer</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">442</span> <span class="k">del</span> <span class="n">buffer</span>
<span class="lineno">443</span> <span class="k">del</span> <span class="n">buffers</span></pre></div>
</div>
</div>
<div class='section' id='section-118'>
<div class='docs'>
<div class='section-link'>
<a href='#section-118'>#</a>
</div>
<p>Set the chunk gradients. This is what the optimizer sees. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">446</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">grad</span>
<span class="lineno">447</span> <span class="k">del</span> <span class="n">grad</span></pre></div>
</div>
</div>
<div class='section' id='section-119'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-119'>#</a>
</div>
<h2>Sequential module for <code class="highlight"><span></span><span class="n">Zero3Layer</span></code>
layers</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">450</span><span class="k">class</span> <span class="nc">Zero3Sequential</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-120'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-120'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">modules</span></code>
List of <code class="highlight"><span></span><span class="n">Zero3Layer</span></code>
layers</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">454</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">modules</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Zero3Layer</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-121'>
<div class='docs'>
<div class='section-link'>
<a href='#section-121'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">458</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-122'>
<div class='docs'>
<div class='section-link'>
<a href='#section-122'>#</a>
</div>
<p>CUDA stream to fetch parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">461</span> <span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-123'>
<div class='docs'>
<div class='section-link'>
<a href='#section-123'>#</a>
</div>
<p>CUDA stream to back up (accumulate) gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">463</span> <span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-124'>
<div class='docs'>
<div class='section-link'>
<a href='#section-124'>#</a>
</div>
<p>Set the streams and the preceding and following layers for each <code class="highlight"><span></span><span class="n">Zero3Layer</span></code>
layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">466</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">modules</span><span class="p">)):</span></pre></div>
</div>
</div>
<div class='section' id='section-125'>
<div class='docs'>
<div class='section-link'>
<a href='#section-125'>#</a>
</div>
<p>Set layer index </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">468</span> <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">layer_idx</span> <span class="o">=</span> <span class="n">i</span></pre></div>
</div>
</div>
<div class='section' id='section-126'>
<div class='docs'>
<div class='section-link'>
<a href='#section-126'>#</a>
</div>
<p>Set streams </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">470</span> <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">fetch_stream</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span>
<span class="lineno">471</span> <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">backup_stream</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span></pre></div>
</div>
</div>
<div class='section' id='section-127'>
<div class='docs'>
<div class='section-link'>
<a href='#section-127'>#</a>
</div>
<p>Set the following layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">473</span> <span class="k">if</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">modules</span><span class="p">):</span>
<span class="lineno">474</span> <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">next_layer</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">modules</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-128'>
<div class='docs'>
<div class='section-link'>
<a href='#section-128'>#</a>
</div>
<p>Set the preceding layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">476</span>            <span class="k">if</span> <span class="n">i</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">477</span>                <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">prev_layer</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">modules</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">])</span></pre></div>
</div>
</div>
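<div class='section'>
<div class='docs'>
<p>The two blocks above link every layer to its immediate neighbours. The layers therefore form a doubly linked chain, presumably so that a layer can find the layer whose parameters should be fetched next. The framework-free sketch below reproduces only that linking step. The <code>Layer</code> class and the <code>link_layers</code> function are hypothetical stand-ins, not the real <code>Zero3Layer</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
from typing import List


class Layer:
    """Hypothetical stand-in for Zero3Layer: only the neighbour links are modelled."""

    def __init__(self, name: str):
        self.name = name
        # The real class keeps neighbours in plain lists; presumably this is so
        # nn.Module does not register them as submodules (an assumption here).
        self.next_layer: List["Layer"] = []
        self.prev_layer: List["Layer"] = []


def link_layers(layers: List[Layer]) -> None:
    """Link each layer to its neighbours, mirroring the loop in Zero3Sequential."""
    for i in range(len(layers)):
        # Succeeding layer, if there is one
        if len(layers) > i + 1:
            layers[i].next_layer.append(layers[i + 1])
        # Preceding layer, if there is one
        if i >= 1:
            layers[i].prev_layer.append(layers[i - 1])


layers = [Layer(f"layer{k}") for k in range(3)]
link_layers(layers)
print([l.name for l in layers[1].prev_layer], [l.name for l in layers[1].next_layer])
```
</pre></div>
</div>
</div>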
<div class='section' id='section-129'>
<div class='docs'>
<div class='section-link'>
<a href='#section-129'>#</a>
</div>
<p>Store list of modules </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">480</span>        <span class="bp">self</span><span class="o">.</span><span class="n">module_list</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">(</span><span class="n">modules</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-130'>
<div class='docs'>
<div class='section-link'>
<a href='#section-130'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">482</span>    <span class="k">def</span> <span class="nf">get_trainable_chunk</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-131'>
<div class='docs'>
<div class='section-link'>
<a href='#section-131'>#</a>
</div>
<p>Return the trainable chunks of all layers, concatenated into a single list </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">484</span>        <span class="k">return</span> <span class="nb">sum</span><span class="p">([</span><span class="n">m</span><span class="o">.</span><span class="n">get_trainable_chunk</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">module_list</span><span class="p">],</span> <span class="p">[])</span></pre></div>
</div>
</div>
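<div class='section'>
<div class='docs'>
<p><code>get_trainable_chunk</code> flattens the per-layer lists with <code>sum(..., [])</code>. That call starts from an empty list and concatenates each layer's list in order. The sketch below uses hypothetical string placeholders in place of parameter tensors. It shows that this idiom gives the same result as <code>itertools.chain.from_iterable</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import itertools
from typing import List


def chunks_with_sum(per_layer: List[List[str]]) -> List[str]:
    # Same idiom as get_trainable_chunk: start from [] and concatenate each list
    return sum(per_layer, [])


def chunks_with_chain(per_layer: List[List[str]]) -> List[str]:
    # Equivalent result; chain does not build a new list at every addition
    return list(itertools.chain.from_iterable(per_layer))


per_layer = [["w0", "b0"], [], ["w2"]]  # hypothetical per-layer chunk lists
print(chunks_with_sum(per_layer))    # ['w0', 'b0', 'w2']
print(chunks_with_chain(per_layer))  # ['w0', 'b0', 'w2']
```
</pre></div>
</div>
</div>
<p>Repeated list addition copies the accumulated list at every step, so its cost grows quadratically with the total length. With one short list per layer this is negligible, and <code>chain</code> only matters for very long lists.</p>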
<div class='section' id='section-132'>
<div class='docs'>
<div class='section-link'>
<a href='#section-132'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">486</span>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-133'>
<div class='docs'>
<div class='section-link'>
<a href='#section-133'>#</a>
</div>
<p>Make sure any gradient backup queued on the backup stream has completed before the forward pass starts </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">488</span>        <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">wait_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-134'>
<div class='docs'>
<div class='section-link'>
<a href='#section-134'>#</a>
</div>
<p>Forward pass </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">491</span>        <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">module_list</span><span class="p">:</span>
<span class="lineno">492</span>            <span class="n">x</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
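<div class='section'>
<div class='docs'>
<p>The forward loop simply calls each layer in order. The sketch below shows, under stated assumptions, one way such a loop could use the <code>next_layer</code> links. Each hypothetical <code>Layer</code> first makes sure its own parameters are fetched, then requests the next layer's parameters, and only then computes. In this sketch everything runs sequentially and a "fetch" only appends to a log. In the real implementation, fetches are meant to be issued on <code>fetch_stream</code> so that they can overlap with computation, and that logic lives in <code>Zero3Layer</code>, outside this section.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
from typing import List


class Layer:
    """Hypothetical layer that records when its parameters are fetched."""

    def __init__(self, idx: int, log: List[str]):
        self.idx = idx
        self.log = log
        self.next_layer: List["Layer"] = []
        self.fetched = False

    def fetch(self) -> None:
        # Stand-in for gathering this layer's parameter shards; fetch only once
        if not self.fetched:
            self.fetched = True
            self.log.append(f"fetch {self.idx}")

    def __call__(self, x: float) -> float:
        self.fetch()  # own parameters must be present before computing
        for nxt in self.next_layer:
            nxt.fetch()  # hypothetical prefetch of the succeeding layer
        self.log.append(f"compute {self.idx}")
        return x + 1.0


log: List[str] = []
layers = [Layer(k, log) for k in range(3)]
for i in range(len(layers) - 1):
    layers[i].next_layer.append(layers[i + 1])

x = 0.0
for m in layers:  # same loop shape as Zero3Sequential.forward
    x = m(x)
print(log)
```
</pre></div>
</div>
</div>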
<div class='section' id='section-135'>
<div class='docs'>
<div class='section-link'>
<a href='#section-135'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">495</span>        <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>