Varuna Jayasiri 1c14551a19 zh
2023-02-28 08:40:22 +05:30


<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This is an implementation of ZeRO-DP memory optimization written in PyTorch."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="ZeRO-DP Memory Optimization"/>
<meta name="twitter:description" content="This is an implementation of ZeRO-DP memory optimization written in PyTorch."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/scaling/zero3/index.html"/>
<meta property="og:title" content="ZeRO-DP Memory Optimization"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="ZeRO-DP Memory Optimization"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="This is an implementation of ZeRO-DP memory optimization written in PyTorch."/>
<title>ZeRO-DP Memory Optimization</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/scaling/zero3/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/scaling/zero3/__init__.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>ZeRO-DP Memory Optimization</h1>
<p>This is an implementation of ZeRO-DP, introduced in the paper <a href="https://papers.labml.ai/paper/1910.02054">ZeRO: Memory Optimizations Toward Training Trillion Parameter Models</a>.</p>
<p>It shards the optimizer state, gradients, and parameters across multiple devices/nodes. This reduces the memory each device needs for the model states to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.4608599999999998em;vertical-align:-0.4508599999999999em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.01em;"><span style="top:-2.655em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqg" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:-0.10903em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.485em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqe" style=""><span class="mord mtight" style="">2</span><span class="mbin mtight" style="">+</span><span class="mord mtight" style="">2</span></span><span class="mbin mtight">+</span><span class="mord mtight coloredeq eqi" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span></span><span class="mclose mtight">)</span><span class="mord mtight coloredeq eqf" style=""><span class="mord mtight" style="">Ψ</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.4508599999999999em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span> bytes, where <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqf" style=""><span class="mord" style="">Ψ</span></span></span></span></span></span> is the number of parameters, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> is the number of shards, and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span></span></span></span></span></span> is the number of optimizer-state bytes per parameter.
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqe" style=""><span class="mord" style="">2</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">2</span></span></span></span></span></span> is the parameter and gradient memory assuming 16-bit precision, i.e. 2 bytes per parameter and 2 bytes per gradient.
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">12</span></span></span></span></span> for the Adam optimizer, because it maintains an fp32 copy of the parameters and two fp32 moments per parameter (4 + 4 + 4 bytes).</p>
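<p>The memory formula is easy to sanity-check with a few lines of arithmetic. The sketch below compares the per-device model-state memory with and without sharding; the function name <code>model_state_bytes</code> and the example sizes (7.5B parameters, 64 shards) are illustrative assumptions, not part of this implementation.</p>

```python
def model_state_bytes(psi: int, n_d: int = 1, k: int = 12) -> float:
    """Bytes of model state held by each device.

    psi: number of parameters (Psi)
    n_d: number of shards (N_d); n_d=1 means no sharding (plain data parallelism)
    k:   optimizer-state bytes per parameter (12 for mixed-precision Adam)
    The 2 + 2 term is 16-bit parameters plus 16-bit gradients.
    """
    return (2 + 2 + k) * psi / n_d


psi = 7_500_000_000                       # hypothetical 7.5B-parameter model
baseline = model_state_bytes(psi)         # every device holds all model states
sharded = model_state_bytes(psi, n_d=64)  # model states sharded across 64 devices

print(f"baseline: {baseline / 1e9:.1f} GB per device")  # baseline: 120.0 GB per device
print(f"ZeRO-DP:  {sharded / 1e9:.3f} GB per device")   # ZeRO-DP:  1.875 GB per device
```

<p>Memory falls linearly with the number of shards, which is what makes very large models fit when enough devices are available.</p>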
<p>The communication volume of ZeRO-DP is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.02778em;">O</span><span class="mopen">(</span><span class="mord">3</span><span class="mord coloredeq eqf" style=""><span class="mord" style="">Ψ</span></span><span class="mclose">)</span></span></span></span></span>, compared with <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.02778em;">O</span><span class="mopen">(</span><span class="mord">2</span><span class="mord coloredeq eqf" style=""><span class="mord" style="">Ψ</span></span><span class="mclose">)</span></span></span></span></span> for standard data-parallel training. Data parallelism only needs to all-reduce the gradients, which amounts to one reduce-scatter and one all-gather. Because ZeRO-DP also shards the parameters, it has to all-gather them during the forward pass and again during the backward pass, in addition to reduce-scattering the gradients.</p>
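<p>The 1.5x factor can be made concrete under a simple cost model. The sketch below assumes ring-style collectives, where a reduce-scatter or an all-gather over Ψ elements sends about (N − 1)/N · Ψ elements per device; the function name and the cost model are assumptions for illustration, not measurements of this implementation.</p>

```python
def elements_sent_per_device(psi: int, n: int, zero_dp: bool) -> float:
    """Elements each device sends per training step, assuming ring collectives.

    psi:     number of parameters (Psi)
    n:       number of devices
    zero_dp: True for ZeRO-DP with sharded parameters, False for plain data parallelism
    """
    # One ring reduce-scatter or all-gather over psi elements costs (n - 1) / n * psi per device
    collective = (n - 1) / n * psi
    if zero_dp:
        # all-gather params (forward) + all-gather params (backward) + reduce-scatter grads
        return 3 * collective
    # all-reduce of gradients = reduce-scatter + all-gather
    return 2 * collective


dp = elements_sent_per_device(1_000_000, n=8, zero_dp=False)
zero = elements_sent_per_device(1_000_000, n=8, zero_dp=True)
print(dp, zero, zero / dp)  # 1750000.0 2625000.0 1.5
```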
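<p>Even though this is named <code class="highlight"><span></span><span class="n">Zero3</span></code>, we have only implemented its ZeRO-DP part, not the ZeRO-R memory optimizations that target residual memory consumption. Our implementation also supports training only a subset of the parameters while keeping the rest fixed.</p>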
<p>This implementation is inspired by <a href="https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html">FairScale FSDP</a>.</p>
<p><a href="finetune_neox.html">Here is a script to fine-tune GPT NeoX using ZeRO-DP memory optimization.</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">functools</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span>

<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.distributed</span> <span class="k">as</span> <span class="nn">dist</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Zero3 Layer</h2>
<p>Each layer of the model (or a combination of a few consecutive layers) should be wrapped in this module.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="k">class</span> <span class="nc">Zero3Layer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>Each shard keeps its parameters in the <code class="highlight"><span></span><span class="n">chunk</span></code> list. <code class="highlight"><span></span><span class="n">chunk</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></code> holds the trainable parameters and <code class="highlight"><span></span><span class="n">chunk</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></code> holds the fixed parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">chunk</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>These are the sizes of the chunks in the <code class="highlight"><span></span><span class="n">chunk</span></code> list.</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">chunk_size</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>The first chunk is for trainable parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">TRAINING_PARAMS_IDX</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>This is the list of parameters, split into a list of trainable parameters and a list of fixed parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">param_refs</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">]]</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>CUDA stream to fetch parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">fetch_stream</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>CUDA stream to back up/accumulate gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">backup_stream</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>List of layers right before this layer</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">prev_layer</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s1">&#39;Zero3Layer&#39;</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>List of layers right after this layer</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">next_layer</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s1">&#39;Zero3Layer&#39;</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Position of the current layer; used for debug logs</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">layer_idx</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Whether the parameters have been fetched</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">is_fetched</span><span class="p">:</span> <span class="nb">bool</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Device of the layer</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Data type of the layer</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">dtype</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Module to be wrapped</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Number of nodes/devices the data is sharded across</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">world_size</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">module</span></code>
 is the module to be wrapped.</li>
<li><code class="highlight"><span></span><span class="n">rank</span></code>
 is the rank of the current node.</li>
<li><code class="highlight"><span></span><span class="n">world_size</span></code>
 is the number of nodes/devices the data is sharded across.</li>
<li><code class="highlight"><span></span><span class="n">device</span></code>
 is the device of the layer.</li>
<li><code class="highlight"><span></span><span class="n">dtype</span></code>
 is the data type of the layer.</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">world_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Initialize the properties</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">module</span> <span class="o">=</span> <span class="n">module</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">prev_layer</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">next_layer</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">is_fetched</span> <span class="o">=</span> <span class="kc">False</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">=</span> <span class="n">world_size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">layer_idx</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span> <span class="o">=</span> <span class="kc">None</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span> <span class="o">=</span> <span class="kc">None</span>

        <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Collect all the parameters of the layer</p>
</div>
<div class='code'>
<div class="highlight"><pre>            <span class="n">all_param_refs</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">parameters</span><span class="p">()]</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Store the shapes of the parameters, because we need them later to reconstruct the parameters</p>
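<p>Once the parameters are merged into one flat (and padded) buffer, only their stored shapes tell us where each parameter starts and ends. The pure-Python sketch below rebuilds parameters from such a buffer; nested lists stand in for tensors, and the helpers <code>reshape</code> and <code>unflatten</code> are hypothetical, not functions from this implementation. With PyTorch tensors, each slice would typically be restored with <code>.view(shape)</code> instead.</p>

```python
from math import prod
from typing import Any, List, Sequence, Tuple

Shape = Tuple[int, ...]


def reshape(flat: Sequence[float], shape: Shape) -> Any:
    """Turn a flat slice back into nested lists with the given shape."""
    if not shape:  # a scalar parameter
        return flat[0]
    if len(shape) == 1:
        return list(flat)
    step = prod(shape[1:])  # elements per entry along the first dimension
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def unflatten(buffer: Sequence[float], shapes: Sequence[Shape]) -> List[Any]:
    """Cut a merged flat buffer back into parameters using their original shapes.

    Any trailing padding in the buffer is simply ignored.
    """
    params, offset = [], 0
    for shape in shapes:
        n = prod(shape)  # number of elements in this parameter
        params.append(reshape(buffer[offset:offset + n], shape))
        offset += n
    return params


# A hypothetical layer with a 2x3 weight and a bias of size 3, merged into one buffer
shapes = [(2, 3), (3,)]
buffer = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
weight, bias = unflatten(buffer, shapes)
print(weight)  # [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(bias)    # [7.0, 8.0, 9.0]
```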
</div>
<div class='code'>
<div class="highlight"><pre>            <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">all_param_refs</span><span class="p">:</span>
                <span class="n">p</span><span class="o">.</span><span class="n">_orig_shape</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>All parameters should have the same data type</p>
</div>
<div class='code'>
<div class="highlight"><pre>            <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">all_param_refs</span><span class="p">:</span>
                <span class="k">assert</span> <span class="n">p</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">dtype</span><span class="p">,</span> <span class="s2">&quot;All parameters should have same dtype&quot;</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Separate the parameters into trainable and fixed</p>
</div>
<div class='code'>
<div class="highlight"><pre>            <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span> <span class="o">=</span> <span class="p">[[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">all_param_refs</span> <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">],</span>
                               <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">all_param_refs</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">p</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">]]</span>
            <span class="k">del</span> <span class="n">all_param_refs</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>The <code class="highlight"><span></span><span class="n">rank</span> <span class="o">=</span> <span class="mi">0</span></code> node calculates how much each device/node should store and distributes the parameters accordingly.</p>
</div>
<div class='code'>
<div class="highlight"><pre>            <span class="k">if</span> <span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Merge and pad the trainable (<code class="highlight"><span></span><span class="n">merged_params</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></code>) and fixed (<code class="highlight"><span></span><span class="n">merged_params</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></code>) parameters</p>
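<p>The helper <code>_merge_and_pad_params</code> is defined outside this section. Judging from how its result is used (its length is divided by <code>world_size</code> and it is later reshaped with <code>view(world_size, -1)</code>), it presumably flattens a group of parameters into one vector whose length is a multiple of <code>world_size</code>. The pure-Python sketch below illustrates that idea under this assumption; <code>merge_and_pad</code> is a hypothetical stand-in, and it pads with zeros only for readability.</p>

```python
from typing import List, Sequence


def merge_and_pad(params: Sequence[Sequence[float]], world_size: int) -> List[float]:
    """Flatten a group of (already flattened) parameters and pad the result
    so that its length is divisible by world_size. Zeros are used as padding here."""
    merged = [x for p in params for x in p]
    padding = (-len(merged)) % world_size  # elements needed to reach the next multiple
    return merged + [0.0] * padding


world_size = 4
trainable = merge_and_pad([[1.0, 2.0, 3.0], [4.0, 5.0]], world_size)  # 5 values, padded to 8
fixed = merge_and_pad([], world_size)                                   # no frozen parameters

# Per-shard chunk sizes, computed like `self.chunk_size` (length // world_size)
chunk_size = [len(p) // world_size for p in (trainable, fixed)]
print(trainable)   # [1.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0, 0.0]
print(chunk_size)  # [2, 0]
```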
</div>
<div class='code'>
<div class="highlight"><pre>                <span class="n">merged_params</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_merge_and_pad_params</span><span class="p">(</span><span class="n">ps</span><span class="p">)</span> <span class="k">for</span> <span class="n">ps</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Calculate the per-shard chunk sizes of the trainable and fixed parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre>                <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span> <span class="o">=</span> <span class="p">[(</span><span class="nb">len</span><span class="p">(</span><span class="n">p</span><span class="p">)</span> <span class="o">//</span> <span class="n">world_size</span> <span class="k">if</span> <span class="n">p</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">merged_params</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Broadcast the sizes to the other nodes</p>
</div>
<div class='code'>
<div class="highlight"><pre>                <span class="n">dist</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">),</span> <span class="n">src</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
            <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Create an empty tensor to receive the sizes</p>
</div>
<div class='code'>
<div class="highlight"><pre>                <span class="n">chunk_size</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Receive the sizes from rank 0</p>
</div>
<div class='code'>
<div class="highlight"><pre>                <span class="n">dist</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="n">chunk_size</span><span class="p">,</span> <span class="n">src</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span> <span class="o">=</span> <span class="n">chunk_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Create parameters for the trainable (<code class="highlight"><span></span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></code>) and fixed (<code class="highlight"><span></span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></code>) parameters to be stored on the current device/node; only the trainable chunk requires gradients.</p>
</div>
<div class='code'>
<div class="highlight"><pre>            <span class="bp">self</span><span class="o">.</span><span class="n">chunk</span> <span class="o">=</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="n">s</span><span class="p">,)),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="n">i</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">)</span>
                          <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>An empty tensor to receive the trainable and fixed parameters combined</p>
</div>
<div class='code'>
<div class="highlight"><pre>            <span class="n">chunk</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">),))</span>

            <span class="k">if</span> <span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
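<p>Concatenate the trainable and fixed parameters so that each node's slice holds its trainable shard followed by its fixed shard</p>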
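<p>The reshaping here is easy to misread. Viewing each merged vector as a <code>(world_size, chunk_size)</code> matrix and concatenating along the last dimension makes row <code>r</code> hold rank <code>r</code>'s trainable shard followed by its fixed shard, so after flattening, splitting into pieces of length <code>sum(self.chunk_size)</code> yields one contiguous piece per rank. The pure-Python sketch below reproduces this layout with lists instead of tensors; <code>scatter_layout</code> and <code>split_piece</code> are hypothetical names used only for illustration.</p>

```python
from typing import List, Sequence


def scatter_layout(merged: Sequence[Sequence[float]], world_size: int) -> List[List[float]]:
    """Return the piece each rank receives, mimicking
    cat([p.view(world_size, -1) for p in merged], dim=-1).view(-1).split(sum(chunk_size))."""
    chunk_size = [len(p) // world_size for p in merged]
    pieces = []
    for r in range(world_size):
        piece = []
        for p, c in zip(merged, chunk_size):
            piece.extend(p[r * c:(r + 1) * c])  # row r of p viewed as (world_size, c)
        pieces.append(piece)
    return pieces


def split_piece(piece: Sequence[float], chunk_size: Sequence[int]) -> List[List[float]]:
    """What a rank does after receiving its piece: split it into trainable and fixed parts."""
    parts, offset = [], 0
    for c in chunk_size:
        parts.append(list(piece[offset:offset + c]))
        offset += c
    return parts


trainable = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]  # already padded to a multiple of world_size
fixed = [10.0, 11.0, 12.0]
pieces = scatter_layout([trainable, fixed], world_size=3)
print(pieces)                          # [[0.0, 1.0, 10.0], [2.0, 3.0, 11.0], [4.0, 5.0, 12.0]]
print(split_piece(pieces[1], [2, 1]))  # [[2.0, 3.0], [11.0]]
```

<p>Because every rank's data is contiguous in the flattened tensor, a single scatter call can deliver both the trainable and fixed shards at once.</p>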
</div>
<div class='code'>
<div class="highlight"><pre>                <span class="n">all_params</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">p</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">world_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">merged_params</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
                <span class="k">del</span> <span class="n">merged_params</span></pre></div>
</div>
</div>
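<div class='section'>
<div class='docs'>
<p>Here is a minimal pure-Python sketch of the layout this produces. It uses flat lists in place of tensors and illustrates the idea; it is not part of this implementation, and the helper name <code>interleave_shards</code> is hypothetical. Each merged, padded list is cut into <code>world_size</code> equal rows, as <code>p.view(world_size, -1)</code> does. The rows are then concatenated side by side, so row <em>r</em> holds rank <em>r</em>'s trainable shard followed by its fixed shard.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def interleave_shards(merged_params, world_size):
    """Flat-list analogue of
    torch.cat([p.view(world_size, -1) for p in merged_params], dim=-1).view(-1).

    Each entry of merged_params is a flat list whose length is already a
    multiple of world_size (i.e. it has been padded).
    """
    rows = [[] for _ in range(world_size)]
    for p in merged_params:
        if len(p) % world_size != 0:
            raise ValueError("each merged list must be padded to a multiple of world_size")
        shard = len(p) // world_size
        for rank in range(world_size):
            # Row `rank` of p.view(world_size, -1)
            rows[rank].extend(p[rank * shard:(rank + 1) * shard])
    # .view(-1): flatten the rows back into a single list
    return [x for row in rows for x in row]


trainable = [1, 2, 3, 4]  # 4 trainable values, world_size = 2
fixed = [10, 20]          # 2 fixed values
print(interleave_shards([trainable, fixed], world_size=2))
# [1, 2, 10, 3, 4, 20]
```
</pre></div>
</div>
</div>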
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Scatter them to all the nodes/devices</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">dist</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">chunk</span><span class="p">,</span> <span class="nb">list</span><span class="p">(</span><span class="n">all_params</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">))))</span>
<span class="lineno">151</span> <span class="k">del</span> <span class="n">all_params</span>
<span class="lineno">152</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Receive this node's parameters from rank 0</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">dist</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">chunk</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
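<p>Split the received chunk into its trainable and fixed parts and copy them into <code class="highlight"><span></span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span></code></p>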
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span> <span class="n">chunk</span> <span class="o">=</span> <span class="n">chunk</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">)</span>
<span class="lineno">158</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">chunk</span><span class="p">):</span>
<span class="lineno">159</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">data</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">c</span>
<span class="lineno">160</span> <span class="k">del</span> <span class="n">chunk</span></pre></div>
</div>
</div>
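<div class='section'>
<div class='docs'>
<p>Below is a pure-Python simulation of what each rank holds after the scatter and the <code>chunk.split(self.chunk_size)</code> steps above. It assumes <code>chunk_size</code> lists the trainable and fixed shard sizes in that order. <code>scatter_and_split</code> is a hypothetical helper, and no communication takes place. The real code uses <code>torch.distributed.scatter</code> and tensor views.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def scatter_and_split(all_params, chunk_size, world_size):
    """Simulate what each rank holds after the scatter and chunk.split steps.

    all_params: flat interleaved list of length world_size * sum(chunk_size).
    chunk_size: section lengths in order, e.g. [trainable, fixed].
    Returns, for each rank, one sub-list per entry of chunk_size.
    """
    piece = sum(chunk_size)
    if len(all_params) != piece * world_size:
        raise ValueError("all_params must hold exactly world_size pieces")
    result = []
    for rank in range(world_size):
        # all_params.split(sum(chunk_size))[rank] is what dist.scatter sends to `rank`
        chunk = all_params[rank * piece:(rank + 1) * piece]
        # chunk.split(chunk_size): consecutive sections of the given lengths
        parts, offset = [], 0
        for size in chunk_size:
            parts.append(chunk[offset:offset + size])
            offset += size
        result.append(parts)
    return result


# Interleaved layout for world_size = 2, chunk_size = [2, 1]
print(scatter_and_split([1, 2, 10, 3, 4, 20], chunk_size=[2, 1], world_size=2))
# [[[1, 2], [10]], [[3, 4], [20]]]
```
</pre></div>
</div>
</div>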
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Clean up the module's regular parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cleanup_params</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Add a backward hook. This gets called when the gradients with respect to the module are computed.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_ref</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_full_backward_hook</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook</span><span class="p">)</span> <span class="c1"># type: ignore</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<h4>Merge all the parameters and pad them so that the total size is divisible by <code class="highlight"><span></span><span class="n">world_size</span></code>
</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="k">def</span> <span class="nf">_merge_and_pad_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Total number of parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">size</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">params</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>If the size is not divisible by <code class="highlight"><span></span><span class="n">world_size</span></code>, pad it</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="k">if</span> <span class="n">size</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">177</span> <span class="n">padding_fixed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">-</span> <span class="p">(</span><span class="n">size</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Otherwise, no padding is needed</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">180</span> <span class="n">padding_fixed</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Create an empty padding tensor</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">182</span> <span class="n">padding</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="n">padding_fixed</span><span class="p">,))</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Concatenate all the parameters and the padding</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">p</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">params</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="n">padding</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
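<div class='section'>
<div class='docs'>
<p>The padding rule is simple modular arithmetic. Below is a small sketch that mirrors the branch above and checks it on a few sizes; the function name <code>padding_for</code> is hypothetical. In Python the same value can also be written as <code>(-size) % world_size</code>, which avoids the branch.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def padding_for(size, world_size):
    """Number of padding elements so that size + padding is divisible by world_size.

    Mirrors the if/else in _merge_and_pad_params.
    """
    if size % world_size != 0:
        return world_size - (size % world_size)
    return 0


for size in (10, 12, 0):
    pad = padding_for(size, world_size=4)
    # size, padding, elements per shard
    print(size, pad, (size + pad) // 4)
# 10 2 3
# 12 0 3
# 0 0 0
```
</pre></div>
</div>
</div>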
<div class='section' id='section-43'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<h3>Get the trainable chunk/shard of the parameters.</h3>
<p>This is what we pass on to the optimizer on the current node.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">186</span> <span class="k">def</span> <span class="nf">get_trainable_chunk</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Return an empty list if there are no trainable parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">])</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">194</span> <span class="k">return</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Return the trainable chunk as a list</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]]</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<h4>Create an empty tensor of the given shape.</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="k">def</span> <span class="nf">_empty</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<h4>Clean up the parameter data</h4>
<p>This frees all the memory used by the layer's parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">206</span> <span class="k">def</span> <span class="nf">_cleanup_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Set the flag to indicate that the parameters are not fetched</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">214</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_fetched</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Iterate through all the parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="k">for</span> <span class="n">ps</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">:</span>
<span class="lineno">218</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">ps</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
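<p>Record that the parameter's memory is in use on the current stream, so the caching allocator does not reuse it until the work already queued on that stream has completed</p>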
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">220</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>Check that the parameter does not share its storage with anything else (it must start at storage offset 0)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">222</span> <span class="k">assert</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">storage_offset</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;The tensor is not the sole occupant of the storage.&quot;</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Resize the storage to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span></span></span></span></span>. This frees the memory used by the parameter.</p>
<p><strong>Setting <code class="highlight"><span></span><span class="n">p</span><span class="o">.</span><span class="n">data</span></code>
will not release the memory, since the autograd graph keeps a reference to it.</strong></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">226</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">storage</span><span class="p">()</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># This is what actually clears the memory</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>Make sure the parameter has no gradient data</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">228</span> <span class="k">assert</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;Gradients should be None&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<h3>Fetch the parameters from all shards</h3>
<p>This fetches the parameter data from all the nodes and rebuilds the full parameters on each node.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">230</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">231</span> <span class="k">def</span> <span class="nf">fetch_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p>Skip if the parameters are already fetched</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">239</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_fetched</span><span class="p">:</span>
<span class="lineno">240</span> <span class="k">return</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Set the flag</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">243</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_fetched</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Skip if there is nothing to fetch or share.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">246</span> <span class="k">if</span> <span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">247</span> <span class="k">return</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Use <code class="highlight"><span></span><span class="n">fetch_stream</span></code>
to fetch the parameters from all the shards</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">250</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>Create an empty tensor to receive the parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">252</span> <span class="n">buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">*</span> <span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">),))</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
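<p>Split the contiguous buffer into one piece per node. These splits are views of <code class="highlight"><span></span><span class="n">buffer</span></code>.</p>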
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">254</span> <span class="n">buffers</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">buffer</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">)))</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<p>Concatenate the trainable and fixed chunks</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">257</span> <span class="n">chunk</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>Gather the parameters from all the nodes/devices</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">260</span> <span class="n">dist</span><span class="o">.</span><span class="n">all_gather</span><span class="p">(</span><span class="n">buffers</span><span class="p">,</span> <span class="n">chunk</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<p>Split the gathered parameters into the trainable and fixed chunks</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">263</span> <span class="n">params</span> <span class="o">=</span> <span class="n">buffer</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">))</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
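<p>Mark the buffers as used on <code class="highlight"><span></span><span class="n">fetch_stream</span></code> so the caching allocator does not reuse their memory before the gather has completed, then drop the references to them</p>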
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">265</span> <span class="n">buffer</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">266</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="n">buffers</span><span class="p">:</span>
<span class="lineno">267</span> <span class="n">b</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">268</span> <span class="n">buffer</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">269</span> <span class="k">del</span> <span class="n">buffer</span>
<span class="lineno">270</span> <span class="k">del</span> <span class="n">buffers</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<p>Reshape the trainable and fixed parameters into contiguous 1-D tensors</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">273</span> <span class="n">params</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">params</span><span class="p">]</span></pre></div>
</div>
</div>
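<div class='section'>
<div class='docs'>
<p>The gather undoes the interleaved layout that was built before scattering. <code>buffer.view(-1, sum(self.chunk_size))</code> has one row per rank. <code>split(self.chunk_size, dim=1)</code> takes the same column block from every row, and <code>reshape(-1)</code> reads that block row by row. Below is a pure-Python sketch of the same reconstruction with lists. The helper <code>gather_and_unshard</code> is hypothetical, and no communication happens here.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def gather_and_unshard(per_rank_chunks, chunk_size):
    """Flat-list analogue of the reconstruction in fetch_params.

    per_rank_chunks[r] is rank r's chunk (trainable shard followed by the
    fixed shard), i.e. what dist.all_gather writes into buffers[r].
    Returns one flat list per entry of chunk_size, like
    buffer.view(-1, sum(chunk_size)).split(chunk_size, dim=1) followed by
    p.reshape(-1) on each piece.
    """
    row_len = sum(chunk_size)
    for c in per_rank_chunks:
        if len(c) != row_len:
            raise ValueError("every rank must contribute sum(chunk_size) elements")
    params, start = [], 0
    for size in chunk_size:
        # Column block [start, start + size) of every row, read row by row
        flat = []
        for row in per_rank_chunks:
            flat.extend(row[start:start + size])
        params.append(flat)
        start += size
    return params


# Chunks held by rank 0 and rank 1 (world_size = 2, chunk_size = [2, 1])
trainable, fixed = gather_and_unshard([[1, 2, 10], [3, 4, 20]], chunk_size=[2, 1])
print(trainable, fixed)
# [1, 2, 3, 4] [10, 20]
```
</pre></div>
</div>
</div>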
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
<p>Rebuild the individual parameter tensors</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">276</span> <span class="k">for</span> <span class="n">cont</span><span class="p">,</span> <span class="n">ps</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
<p>Skip if there are no parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">278</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">ps</span><span class="p">:</span>
<span class="lineno">279</span> <span class="k">continue</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
<p>Offset into the contiguous tensor</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">282</span> <span class="n">offset</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-70'>
<div class='docs'>
<div class='section-link'>
<a href='#section-70'>#</a>
</div>
<p>Iterate through the model parameters and assign them values from the contiguous tensor</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">284</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">ps</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-71'>
<div class='docs'>
<div class='section-link'>
<a href='#section-71'>#</a>
</div>
<p>Original parameter shape</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">286</span> <span class="n">shape</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">_orig_shape</span> <span class="c1"># type: ignore[attr-defined]</span></pre></div>
</div>
</div>
<div class='section' id='section-72'>
<div class='docs'>
<div class='section-link'>
<a href='#section-72'>#</a>
</div>
<p>Resize the parameter's storage back to its original number of elements. It was set to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span></span></span></span></span> when we cleaned up the parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">288</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">storage</span><span class="p">()</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span><span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-73'>
<div class='docs'>
<div class='section-link'>
<a href='#section-73'>#</a>
</div>
<p>Assign the values from the contiguous tensor</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">290</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">cont</span><span class="p">[</span><span class="n">offset</span><span class="p">:</span> <span class="n">offset</span> <span class="o">+</span> <span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-74'>
<div class='docs'>
<div class='section-link'>
<a href='#section-74'>#</a>
</div>
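<p>Mark the parameter's memory as used on <code class="highlight"><span></span><span class="n">fetch_stream</span></code>, so it is not reused before the work queued on that stream completes</p>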
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">292</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-75'>
<div class='docs'>
<div class='section-link'>
<a href='#section-75'>#</a>
</div>
<p>Update the offset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">294</span> <span class="n">offset</span> <span class="o">+=</span> <span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-76'>
<div class='docs'>
<div class='section-link'>
<a href='#section-76'>#</a>
</div>
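<p>Mark the contiguous tensor as used on <code class="highlight"><span></span><span class="n">fetch_stream</span></code>, so its memory is not reused before the work queued on that stream completes</p>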
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">297</span> <span class="n">cont</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span></pre></div>
</div>
</div>
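<div class='section'>
<div class='docs'>
<p>The offset loop above walks the contiguous tensor and gives each parameter the next <code>shape.numel()</code> elements. Below is a pure-Python sketch of that slicing. It assumes shapes are given as tuples, and nested lists stand in for reshaped tensors. The helper <code>unflatten</code> is hypothetical. Any trailing padding is simply never read.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
from math import prod


def unflatten(cont, shapes):
    """Slice a flat list into nested lists with the given shapes.

    Mirrors the offset loop in fetch_params: each parameter takes the next
    prod(shape) elements of the contiguous data. Trailing elements (such as
    the padding added before sharding) are ignored.
    """
    def build(values, shape):
        if not shape:
            return values[0]  # 0-d (scalar) parameter
        step = prod(shape[1:])
        return [build(values[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]

    out, offset = [], 0
    for shape in shapes:
        n = prod(shape)
        if offset + n > len(cont):
            raise ValueError("contiguous data is too short for the given shapes")
        out.append(build(cont[offset:offset + n], shape))
        offset += n  # same role as `offset += shape.numel()`
    return out


# A weight of shape (2, 3) and a bias of shape (2,), followed by 2 padding elements
weight, bias = unflatten([1, 2, 3, 4, 5, 6, 7, 8, 0, 0], [(2, 3), (2,)])
print(weight)  # [[1, 2, 3], [4, 5, 6]]
print(bias)    # [7, 8]
```
</pre></div>
</div>
</div>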
<div class='section' id='section-77'>
<div class='docs'>
<div class='section-link'>
<a href='#section-77'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">300</span> <span class="k">del</span> <span class="n">params</span></pre></div>
</div>
</div>
<div class='section' id='section-78'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-78'>#</a>
</div>
<h3>Forward pass</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">302</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-79'>
<div class='docs'>
<div class='section-link'>
<a href='#section-79'>#</a>
</div>
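<p>Fetch all the parameters for the current node. The previous layer normally starts this fetch, so this call only makes sure the parameters are fetched.</p>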
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">309</span> <span class="bp">self</span><span class="o">.</span><span class="n">fetch_params</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-80'>
<div class='docs'>
<div class='section-link'>
<a href='#section-80'>#</a>
</div>
<p>Wait for the parameter fetching to complete.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">312</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">wait_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-81'>
<div class='docs'>
<div class='section-link'>
<a href='#section-81'>#</a>
</div>
<p>Start fetching the parameters of the following layers, so that they are gathered while the current layer does its computation.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">316</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_layer</span><span class="p">:</span>
<span class="lineno">317</span> <span class="n">layer</span><span class="o">.</span><span class="n">fetch_params</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-82'>
<div class='docs'>
<div class='section-link'>
<a href='#section-82'>#</a>
</div>
<p>Add backward hooks to the parameters of the current layer if autograd is enabled.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">320</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">is_grad_enabled</span><span class="p">():</span>
<span class="lineno">321</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_backward_hooks</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-83'>
<div class='docs'>
<div class='section-link'>
<a href='#section-83'>#</a>
</div>
<p>Compute the outputs of the current layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">324</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">module</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-84'>
<div class='docs'>
<div class='section-link'>
<a href='#section-84'>#</a>
</div>
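<p>Clean up the parameters of the layer.</p>
<p><em>Skip the cleanup if autograd is enabled and this is the last layer in the network. The backward pass starts with this layer, so cleaning up would only force us to fetch the parameters again.</em></p>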
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">330</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">is_grad_enabled</span><span class="p">()</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_layer</span><span class="p">:</span>
<span class="lineno">331</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cleanup_params</span><span class="p">()</span>
<span class="lineno">332</span>
<span class="lineno">333</span> <span class="k">return</span> <span class="n">res</span></pre></div>
</div>
</div>
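<div class='section'>
<div class='docs'>
<p>A minimal pure-Python sketch of the forward-pass ordering described above, under stated assumptions. <code>SimLayer</code>, <code>run</code> and the string log are hypothetical stand-ins for <code>Zero3Layer</code>, its CUDA streams and its fetch/cleanup methods. It also assumes <code>fetch_params</code> does nothing when the parameters are already present. The log shows the next layer being prefetched before the current layer computes. It also shows that the last layer keeps its parameters only when autograd is on.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
from typing import List


class SimLayer:
    """Hypothetical stand-in for Zero3Layer that records actions instead of doing them."""

    def __init__(self, name: str, log: List[str]):
        self.name = name
        self.log = log
        self.next_layer: List["SimLayer"] = []
        self.fetched = False  # assumption: a second fetch is a no-op

    def fetch_params(self):
        # Stands in for starting an all-gather of the shards on the fetch stream
        if not self.fetched:
            self.fetched = True
            self.log.append(f"fetch {self.name}")

    def cleanup_params(self):
        # Stands in for freeing the gathered (full) parameters
        self.fetched = False
        self.log.append(f"cleanup {self.name}")

    def forward(self, x: int, grad_enabled: bool) -> int:
        self.fetch_params()  # only does work for the first layer
        # Prefetch the following layers so their gather overlaps this compute
        for layer in self.next_layer:
            layer.fetch_params()
        self.log.append(f"compute {self.name}")
        res = x + 1
        # The last layer keeps its parameters when autograd is on,
        # because the backward pass starts with it
        if not grad_enabled or self.next_layer:
            self.cleanup_params()
        return res


def run(grad_enabled: bool) -> List[str]:
    log: List[str] = []
    layers = [SimLayer(n, log) for n in ("a", "b", "c")]
    for prev, nxt in zip(layers, layers[1:]):
        prev.next_layer.append(nxt)
    x = 0
    for layer in layers:
        x = layer.forward(x, grad_enabled)
    return log
```
</pre></div>
</div>
</div>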
<div class='section' id='section-85'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-85'>#</a>
</div>
<h4>Add backward hooks to the parameters of the current layer.</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">335</span> <span class="k">def</span> <span class="nf">_add_backward_hooks</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-86'>
<div class='docs'>
<div class='section-link'>
<a href='#section-86'>#</a>
</div>
<p>Number of backward hooks added</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">341</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_handles</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-87'>
<div class='docs'>
<div class='section-link'>
<a href='#section-87'>#</a>
</div>
<p>Loop through the trainable parameters of the current layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">344</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-88'>
<div class='docs'>
<div class='section-link'>
<a href='#section-88'>#</a>
</div>
<p>Make sure a hook hasn't already been added</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">346</span> <span class="k">assert</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="s2">&quot;_hook_handle&quot;</span><span class="p">),</span> <span class="s1">&#39;Parameter has already been hooked&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-89'>
<div class='docs'>
<div class='section-link'>
<a href='#section-89'>#</a>
</div>
<p>Use <code class="highlight"><span></span><span class="n">expand_as</span></code> to create an autograd step that we can intercept</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">348</span> <span class="n">p_tmp</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">expand_as</span><span class="p">(</span><span class="n">p</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-90'>
<div class='docs'>
<div class='section-link'>
<a href='#section-90'>#</a>
</div>
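<p>Get the gradient accumulator (<code>AccumulateGrad</code>) node of the parameter so we can attach the backward hook to it. <a href="https://amsword.medium.com/understanding-pytorchs-autograd-with-grad-fn-and-next-functions-b2c4836daa00">This blog discusses <code class="highlight"><span></span><span class="n">grad_acc</span></code></a>.</p>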
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">351</span> <span class="n">grad_acc</span> <span class="o">=</span> <span class="n">p_tmp</span><span class="o">.</span><span class="n">grad_fn</span><span class="o">.</span><span class="n">next_functions</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-91'>
<div class='docs'>
<div class='section-link'>
<a href='#section-91'>#</a>
</div>
<p>Add the backward hook</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">353</span> <span class="n">handle</span> <span class="o">=</span> <span class="n">grad_acc</span><span class="o">.</span><span class="n">register_hook</span><span class="p">(</span>
<span class="lineno">354</span> <span class="n">functools</span><span class="o">.</span><span class="n">partial</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_post_backward_hook</span><span class="p">,</span> <span class="n">p</span><span class="p">))</span></pre></div>
</div>
</div>
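<div class='section'>
<div class='docs'>
<p>A small CPU-only sketch of the same trick, assuming a recent PyTorch. The <code>expand_as</code> view exposes the parameter's <code>AccumulateGrad</code> node. A hook registered on that node fires after the gradient has been accumulated into <code>p.grad</code>. The names <code>post_backward_hook</code> and <code>calls</code> are illustrative. Like <code>_post_backward_hook</code>, the hook removes itself after its first call.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import functools

import torch
import torch.nn as nn

calls = []  # gradients observed by the hook (illustrative)


def post_backward_hook(p: nn.Parameter, *args):
    # Runs after AccumulateGrad for `p`, so p.grad is already populated
    calls.append(p.grad.clone())
    # One-shot hook: remove it and drop the stored handle
    p._hook_handle.remove()
    del p._hook_handle


p = nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
# The expand_as view's grad_fn has p's AccumulateGrad node as its next function
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]
# grad_acc stays referenced, so the backward below goes through this same node
p._hook_handle = grad_acc.register_hook(functools.partial(post_backward_hook, p))

(p * 2).sum().backward()  # hook fires once; d/dp sum(2p) = 2
(p * 2).sum().backward()  # hook was removed, so it does not fire again
```
</pre></div>
</div>
</div>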
<div class='section' id='section-92'>
<div class='docs'>
<div class='section-link'>
<a href='#section-92'>#</a>
</div>
<p>Keep a reference to the handle</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">356</span> <span class="n">p</span><span class="o">.</span><span class="n">_hook_handle</span> <span class="o">=</span> <span class="n">handle</span></pre></div>
</div>
</div>
<div class='section' id='section-93'>
<div class='docs'>
<div class='section-link'>
<a href='#section-93'>#</a>
</div>
<p>Increment the number of hooks added</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">358</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_handles</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-94'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-94'>#</a>
</div>
<h4>Handle a backward event</h4>
<p>This gets called by the parameter backward hooks and by the module backward hook.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">360</span> <span class="k">def</span> <span class="nf">_backward_event</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-95'>
<div class='docs'>
<div class='section-link'>
<a href='#section-95'>#</a>
</div>
<p>Decrement the hooks counter</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">368</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_handles</span> <span class="o">-=</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-96'>
<div class='docs'>
<div class='section-link'>
<a href='#section-96'>#</a>
</div>
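<p>If all the hooks (including the module hook) have been called, we can back up the gradients and clean up the parameters. The counter starts at the number of parameter hooks, and the module hook decrements it once more. So it reaches -1 only after every hook has fired.</p>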
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">372</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_handles</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
<span class="lineno">373</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backup_grads</span><span class="p">()</span>
<span class="lineno">374</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cleanup_params</span><span class="p">()</span></pre></div>
</div>
</div>
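<div class='section'>
<div class='docs'>
<p>A hypothetical pure-Python mirror of the counter logic; <code>HookCounter</code> is not part of the source. With <code>n_params</code> parameter hooks registered, the counter starts at <code>n_params</code>. Each parameter hook and the single module hook subtract one. The counter therefore reaches -1 exactly once, after all <code>n_params + 1</code> hooks have fired, in whatever order they arrive.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
from typing import List


class HookCounter:
    """Hypothetical mirror of the _backward_hook_handles bookkeeping."""

    def __init__(self, n_params: int):
        # One hook per trainable parameter is registered during the forward pass
        self.handles = n_params
        self.events: List[str] = []

    def backward_event(self, source: str):
        self.handles -= 1
        self.events.append(source)
        # n_params parameter hooks plus one module hook bring the counter to -1
        if self.handles == -1:
            self.events.append("backup+cleanup")


c = HookCounter(n_params=2)
c.backward_event("module")  # the module hook may arrive before the parameter hooks
c.backward_event("param")
c.backward_event("param")
```
</pre></div>
</div>
</div>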
<div class='section' id='section-97'>
<div class='docs'>
<div class='section-link'>
<a href='#section-97'>#</a>
</div>
<p>Start fetching the parameters of the previous layer, because autograd will process its gradients next.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">377</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">prev_layer</span><span class="p">:</span>
<span class="lineno">378</span> <span class="n">layer</span><span class="o">.</span><span class="n">fetch_params</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-98'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-98'>#</a>
</div>
<h4>Parameter backward hook</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">380</span> <span class="k">def</span> <span class="nf">_post_backward_hook</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-99'>
<div class='docs'>
<div class='section-link'>
<a href='#section-99'>#</a>
</div>
<p>Remove the handle from the parameter</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">385</span> <span class="n">p</span><span class="o">.</span><span class="n">_hook_handle</span><span class="o">.</span><span class="n">remove</span><span class="p">()</span> <span class="c1"># type: ignore[attr-defined]</span>
<span class="lineno">386</span> <span class="nb">delattr</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="s2">&quot;_hook_handle&quot;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-100'>
<div class='docs'>
<div class='section-link'>
<a href='#section-100'>#</a>
</div>
<p>Handle a backward event</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">389</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_event</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-101'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-101'>#</a>
</div>
<h4>Module backward hook</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">391</span> <span class="k">def</span> <span class="nf">_backward_hook</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-102'>
<div class='docs'>
<div class='section-link'>
<a href='#section-102'>#</a>
</div>
<p>Handle a backward event</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">396</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_event</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-103'>
<div class='docs'>
<div class='section-link'>
<a href='#section-103'>#</a>
</div>
<p>The previous layer will start computing gradients next. We need to make sure it has finished fetching its parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">399</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">wait_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-104'>
<div class='docs'>
<div class='section-link'>
<a href='#section-104'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">402</span> <span class="k">return</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-105'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-105'>#</a>
</div>
<h3>Back up the gradients of the current layer</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">404</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">405</span> <span class="k">def</span> <span class="nf">_backup_grads</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-106'>
<div class='docs'>
<div class='section-link'>
<a href='#section-106'>#</a>
</div>
<p>Skip if there are no trainable parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">410</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">411</span> <span class="k">return</span></pre></div>
</div>
</div>
<div class='section' id='section-107'>
<div class='docs'>
<div class='section-link'>
<a href='#section-107'>#</a>
</div>
<p>Use the backup stream to back up the gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">414</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-108'>
<div class='docs'>
<div class='section-link'>
<a href='#section-108'>#</a>
</div>
<p>Buffer to store the gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">416</span> <span class="n">buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">],))</span></pre></div>
</div>
</div>
<div class='section' id='section-109'>
<div class='docs'>
<div class='section-link'>
<a href='#section-109'>#</a>
</div>
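<p>Split the contiguous buffer into one piece per node. These pieces are views of <code class="highlight"><span></span><span class="n">buffer</span></code>.</p>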
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">418</span> <span class="n">buffers</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">buffer</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]))</span></pre></div>
</div>
</div>
<div class='section' id='section-110'>
<div class='docs'>
<div class='section-link'>
<a href='#section-110'>#</a>
</div>
<p>Offset into the contiguous buffer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">421</span> <span class="n">offset</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-111'>
<div class='docs'>
<div class='section-link'>
<a href='#section-111'>#</a>
</div>
<p>Iterate through the trainable parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">423</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-112'>
<div class='docs'>
<div class='section-link'>
<a href='#section-112'>#</a>
</div>
<p>Collect the gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">425</span> <span class="n">shape</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">_orig_shape</span> <span class="c1"># type: ignore[attr-defined]</span>
<span class="lineno">426</span> <span class="n">buffer</span><span class="p">[</span><span class="n">offset</span><span class="p">:</span> <span class="n">offset</span> <span class="o">+</span> <span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()]</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-113'>
<div class='docs'>
<div class='section-link'>
<a href='#section-113'>#</a>
</div>
<p>Update the offset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">428</span> <span class="n">offset</span> <span class="o">+=</span> <span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-114'>
<div class='docs'>
<div class='section-link'>
<a href='#section-114'>#</a>
</div>
<p>Clean up the gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">430</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
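<div class='section'>
<div class='docs'>
<p>A CPU-only sketch of the buffer layout. It assumes a hypothetical layer whose two trainable parameters have 3 and 2 elements, with <code>chunk_size = 3</code> for <code>world_size = 2</code>; that padding rule is an assumption here. It uses <code>torch.zeros</code> instead of <code>_empty</code> so the padding element is deterministic. The pieces returned by <code>split</code> are views, so writing the flattened gradients into <code>buffer</code> fills the per-rank pieces as well.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch

world_size = 2
chunk_size = 3  # assumption: 3 + 2 = 5 elements, padded to world_size * chunk_size = 6

# Gradients of a hypothetical layer with a (1, 3) and a (2,) parameter
grads = [torch.ones(1, 3), torch.full((2,), 2.0)]

buffer = torch.zeros(world_size * chunk_size)
# One piece per rank; each piece is a view into `buffer`
buffers = list(buffer.split(chunk_size))

offset = 0
for g in grads:
    n = g.numel()
    # Flatten each gradient into its slot of the contiguous buffer
    buffer[offset: offset + n] = g.view(-1)
    offset += n
```
</pre></div>
</div>
</div>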
<div class='section' id='section-115'>
<div class='docs'>
<div class='section-link'>
<a href='#section-115'>#</a>
</div>
<p>Empty tensor to accumulate the gradients of the current shard</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">433</span> <span class="n">grad</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">],))</span></pre></div>
</div>
</div>
<div class='section' id='section-116'>
<div class='docs'>
<div class='section-link'>
<a href='#section-116'>#</a>
</div>
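<p>Accumulate the gradients of each shard. This scatters the buffers across the nodes, and each node accumulates (reduces) the tensors it receives.</p>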
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">436</span> <span class="n">dist</span><span class="o">.</span><span class="n">reduce_scatter</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="n">buffers</span><span class="p">)</span></pre></div>
</div>
</div>
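<div class='section'>
<div class='docs'>
<p>To make the data movement concrete, here is a single-process pure-Python simulation of what <code>dist.reduce_scatter(grad, buffers)</code> computes with its default <code>ReduceOp.SUM</code>. Rank <code>i</code> receives the element-wise sum, over all ranks, of chunk <code>i</code>. <code>reduce_scatter_sim</code> is a hypothetical helper and performs no communication.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
from typing import List


def reduce_scatter_sim(per_rank_buffers: List[List[List[float]]]) -> List[List[float]]:
    """Simulate reduce_scatter with a sum reduction.

    per_rank_buffers[r][i] is chunk i of rank r's gradient buffer.
    The returned list holds, at index i, what rank i would receive.
    """
    world_size = len(per_rank_buffers)
    outputs = []
    for i in range(world_size):
        # Chunk i from every rank is summed element-wise and lands on rank i
        chunks = [per_rank_buffers[r][i] for r in range(world_size)]
        outputs.append([sum(vals) for vals in zip(*chunks)])
    return outputs


per_rank = [
    [[1.0, 1.0], [2.0, 2.0]],      # rank 0: chunk 0, chunk 1
    [[10.0, 10.0], [20.0, 20.0]],  # rank 1: chunk 0, chunk 1
]
shards = reduce_scatter_sim(per_rank)
```
</pre></div>
</div>
</div>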
<div class='section' id='section-117'>
<div class='docs'>
<div class='section-link'>
<a href='#section-117'>#</a>
</div>
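<p>Record that the buffers are used on the fetch stream, so that the CUDA caching allocator does not reuse their memory until the work queued on that stream has completed. Then drop the references to the buffers.</p>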
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">439</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="n">buffers</span><span class="p">:</span>
<span class="lineno">440</span> <span class="n">b</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">441</span> <span class="n">buffer</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">442</span> <span class="k">del</span> <span class="n">buffer</span>
<span class="lineno">443</span> <span class="k">del</span> <span class="n">buffers</span></pre></div>
</div>
</div>
<div class='section' id='section-118'>
<div class='docs'>
<div class='section-link'>
<a href='#section-118'>#</a>
</div>
<p>Set the gradient of the chunk. This is what the optimizer sees.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">446</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">grad</span>
<span class="lineno">447</span> <span class="k">del</span> <span class="n">grad</span></pre></div>
</div>
</div>
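<div class='section'>
<div class='docs'>
<p>A CPU-only sketch of why assigning <code>.grad</code> on the chunk is enough. A standard PyTorch optimizer only reads <code>param.grad</code>, so each rank updates just its own shard. <code>chunk</code> and the gradient values are made up. Plain SGD is used only for illustration; this section does not say which optimizer is used.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn as nn

# This rank's shard of the flattened trainable parameters (illustrative values)
chunk = nn.Parameter(torch.zeros(3))
optimizer = torch.optim.SGD([chunk], lr=0.1)

# Gradient shard that reduce_scatter would produce on this rank (made-up values)
grad = torch.tensor([1.0, 2.0, 3.0])
chunk.grad = grad
del grad

optimizer.step()  # chunk <- chunk - lr * chunk.grad
```
</pre></div>
</div>
</div>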
<div class='section' id='section-119'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-119'>#</a>
</div>
<h2>Sequential module for <code class="highlight"><span></span><span class="n">Zero3Layer</span></code> layers</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">450</span><span class="k">class</span> <span class="nc">Zero3Sequential</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-120'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-120'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">modules</span></code> is the list of <code class="highlight"><span></span><span class="n">Zero3Layer</span></code> layers</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">454</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">modules</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Zero3Layer</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-121'>
<div class='docs'>
<div class='section-link'>
<a href='#section-121'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">458</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-122'>
<div class='docs'>
<div class='section-link'>
<a href='#section-122'>#</a>
</div>
<p>CUDA stream to fetch parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">461</span> <span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-123'>
<div class='docs'>
<div class='section-link'>
<a href='#section-123'>#</a>
</div>
<p>CUDA stream to back up (accumulate) gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">463</span> <span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-124'>
<div class='docs'>
<div class='section-link'>
<a href='#section-124'>#</a>
</div>
<p>Set the streams and the preceding and following <code class="highlight"><span></span><span class="n">Zero3Layer</span></code> of each layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">466</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">modules</span><span class="p">)):</span></pre></div>
</div>
</div>
<div class='section' id='section-125'>
<div class='docs'>
<div class='section-link'>
<a href='#section-125'>#</a>
</div>
<p>Set the layer index</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">468</span> <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">layer_idx</span> <span class="o">=</span> <span class="n">i</span></pre></div>
</div>
</div>
<div class='section' id='section-126'>
<div class='docs'>
<div class='section-link'>
<a href='#section-126'>#</a>
</div>
<p>Set the streams</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">470</span> <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">fetch_stream</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span>
<span class="lineno">471</span> <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">backup_stream</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span></pre></div>
</div>
</div>
<div class='section' id='section-127'>
<div class='docs'>
<div class='section-link'>
<a href='#section-127'>#</a>
</div>
<p>Set the next layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">473</span> <span class="k">if</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">modules</span><span class="p">):</span>
<span class="lineno">474</span> <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">next_layer</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">modules</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-128'>
<div class='docs'>
<div class='section-link'>
<a href='#section-128'>#</a>
</div>
<p>Set the previous layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">476</span> <span class="k">if</span> <span class="n">i</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">477</span> <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">prev_layer</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">modules</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">])</span></pre></div>
</div>
</div>
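<div class='section'>
<div class='docs'>
<p>A pure-Python sketch of the wiring above, with a hypothetical <code>Node</code> class standing in for <code>Zero3Layer</code>. Because <code>next_layer</code> and <code>prev_layer</code> are lists, the first and last layers simply get empty lists. Loops such as <code>for layer in self.next_layer</code> therefore need no special case at the ends of the network.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
from typing import List, Optional


class Node:
    """Hypothetical stand-in for Zero3Layer, keeping only the linking attributes."""

    def __init__(self):
        self.layer_idx: Optional[int] = None
        self.next_layer: List["Node"] = []
        self.prev_layer: List["Node"] = []


def link(modules: List[Node]) -> None:
    # Same wiring as Zero3Sequential.__init__, without the CUDA streams
    for i in range(len(modules)):
        modules[i].layer_idx = i
        if i + 1 < len(modules):
            modules[i].next_layer.append(modules[i + 1])
        if i - 1 >= 0:
            modules[i].prev_layer.append(modules[i - 1])


layers = [Node() for _ in range(3)]
link(layers)
```
</pre></div>
</div>
</div>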
<div class='section' id='section-129'>
<div class='docs'>
<div class='section-link'>
<a href='#section-129'>#</a>
</div>
<p>Store the list of modules</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">480</span> <span class="bp">self</span><span class="o">.</span><span class="n">module_list</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">(</span><span class="n">modules</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-130'>
<div class='docs'>
<div class='section-link'>
<a href='#section-130'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">482</span> <span class="k">def</span> <span class="nf">get_trainable_chunk</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-131'>
<div class='docs'>
<div class='section-link'>
<a href='#section-131'>#</a>
</div>
<p>Return the trainable chunks of all the layers as a single flat list</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">484</span> <span class="k">return</span> <span class="nb">sum</span><span class="p">([</span><span class="n">m</span><span class="o">.</span><span class="n">get_trainable_chunk</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">module_list</span><span class="p">],</span> <span class="p">[])</span></pre></div>
</div>
</div>
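<div class='section'>
<div class='docs'>
<p>A small pure-Python illustration of the <code>sum(..., [])</code> idiom used above. The per-layer results are made up; the sketch assumes that a layer without trainable parameters returns an empty list. Starting <code>sum</code> from <code>[]</code> concatenates the per-layer lists. <code>itertools.chain.from_iterable</code> gives the same result in linear time, because it does not build a new list at every step.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import itertools

# Hypothetical per-layer results: a layer with no trainable parameters returns []
per_layer = [["chunk0"], [], ["chunk2a", "chunk2b"]]

# sum() with a [] start value concatenates the lists, as get_trainable_chunk does
flat_sum = sum(per_layer, [])

# Linear-time alternative that avoids building a new list at every step
flat_chain = list(itertools.chain.from_iterable(per_layer))
```
</pre></div>
</div>
</div>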
<div class='section' id='section-132'>
<div class='docs'>
<div class='section-link'>
<a href='#section-132'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">486</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-133'>
<div class='docs'>
<div class='section-link'>
<a href='#section-133'>#</a>
</div>
<p>Make sure the gradient backup is done</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">488</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">wait_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-134'>
<div class='docs'>
<div class='section-link'>
<a href='#section-134'>#</a>
</div>
<p>Forward pass</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">491</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">module_list</span><span class="p">:</span>
<span class="lineno">492</span> <span class="n">x</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-135'>
<div class='docs'>
<div class='section-link'>
<a href='#section-135'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">495</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>