Varuna Jayasiri 2038b11d29 ja translation
2023-05-10 17:00:29 -04:00

<!DOCTYPE html>
<html lang="ja">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This is an implementation of Zero-DP memory optimization written in PyTorch."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Zero-DP Memory Optimization"/>
<meta name="twitter:description" content="This is an implementation of Zero-DP memory optimization written in PyTorch."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/scaling/zero3/index.html"/>
<meta property="og:title" content="Zero-DP Memory Optimization"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Zero-DP Memory Optimization"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="This is an implementation of Zero-DP memory optimization written in PyTorch."/>
<title>Zero-DP Memory Optimization</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/scaling/zero3/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">scaling</a>
<a class="parent" href="index.html">zero3</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/scaling/zero3/__init__.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Zero-DP Memory Optimization</h1>
<p>This is an implementation of Zero-DP, introduced in the paper <a href="https://papers.labml.ai/paper/1910.02054">ZeRO: Memory Optimizations Toward Training Trillion Parameter Models</a>.</p>
<p>It keeps shards of the optimizer state, gradients, and parameters on multiple devices/nodes. This reduces the memory consumption per device to
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.4608599999999998em;vertical-align:-0.4508599999999999em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.01em;"><span style="top:-2.655em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqg" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:-0.10903em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.485em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqe" style=""><span class="mord mtight" style="">2</span><span class="mbin mtight" style="">+</span><span class="mord mtight" style="">2</span></span><span class="mbin mtight">+</span><span class="mord mtight coloredeq eqi" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span></span><span class="mclose mtight">)</span><span class="mord mtight coloredeq eqf" style=""><span class="mord mtight" style="">Ψ</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.4508599999999999em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span>
bytes, where
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqf" style=""><span class="mord" style="">Ψ</span></span></span></span></span></span>
is the number of parameters,
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>
is the number of shards, and
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span></span></span></span></span></span>
is the number of optimizer bytes per parameter.
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqe" style=""><span class="mord" style="">2</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">2</span></span></span></span></span></span>
is the parameter and gradient memory assuming 16-bit precision, i.e. 2 bytes per parameter and 2 bytes per gradient.
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">12</span></span></span></span></span>
for the Adam optimizer, because it keeps a copy of the parameters and two moments per parameter, all in fp32.</p>
<p>The communication volume of Zero-DP is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.02778em;">O</span><span class="mopen">(</span><span class="mord">3</span><span class="mord coloredeq eqf" style=""><span class="mord" style="">Ψ</span></span><span class="mclose">)</span></span></span></span></span>. For comparison, data-parallel training has a communication volume of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.02778em;">O</span><span class="mopen">(</span><span class="mord">2</span><span class="mord coloredeq eqf" style=""><span class="mord" style="">Ψ</span></span><span class="mclose">)</span></span></span></span></span>.</p>
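<p>Although this is named <code class="highlight"><span></span><span class="n">Zero3</span></code>,
we have only implemented the Zero-DP part and not the Zero-R memory optimizations, which target residual memory consumption. Our implementation supports training only a subset of the parameters.</p>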
<p>This implementation is inspired by <a href="https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html">Fairscale FSDP</a>.</p>
<p><a href="finetune_neox.html">Here is a script to fine-tune GPT NeoX using Zero-DP memory optimization</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span></span><span class="kn">import</span> <span class="nn">functools</span>
<span class="lineno">33</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span>
<span class="lineno">34</span>
<span class="lineno">35</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">36</span><span class="kn">import</span> <span class="nn">torch.distributed</span> <span class="k">as</span> <span class="nn">dist</span>
<span class="lineno">37</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span></pre></div>
</div>
</div>
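<div class='section'>
<div class='docs'>
<p>The memory formula above can be checked with a little arithmetic. The following is a minimal sketch, not part of the implementation: the function name <code>zero_dp_bytes_per_device</code> and the 7.5-billion-parameter model size are assumptions chosen only for illustration.</p>

```python
def zero_dp_bytes_per_device(num_params: int, num_shards: int, optimizer_bytes: int = 12) -> float:
    """Per-device bytes for fp16 parameters, fp16 gradients and optimizer state.

    Hypothetical helper that evaluates (2 + 2 + K) * Psi / N_d.
    """
    # 2 bytes per fp16 parameter + 2 bytes per fp16 gradient + K optimizer bytes
    bytes_per_param = 2 + 2 + optimizer_bytes
    return bytes_per_param * num_params / num_shards


# Illustrative model size (an assumption, not taken from this page), with Adam (K = 12)
psi = 7_500_000_000
unsharded = zero_dp_bytes_per_device(psi, num_shards=1)
sharded = zero_dp_bytes_per_device(psi, num_shards=64)
print(f"{unsharded / 1e9:.1f} GB without sharding")
print(f"{sharded / 1e9:.3f} GB per device with 64 shards")
```

<p>This counts only model states (parameters, gradients, optimizer state); activations and temporary buffers are extra.</p>
</div>
</div>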
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Zero3 Layer</h2>
<p>Each layer of the model (or a combination of a few consecutive layers) should be wrapped in this module.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span><span class="k">class</span> <span class="nc">Zero3Layer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>Each shard keeps its parameters in the <code class="highlight"><span></span><span class="n">chunk</span></code>
list. <code class="highlight"><span></span><span class="n">chunk</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></code>
is for trainable parameters and <code class="highlight"><span></span><span class="n">chunk</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></code>
is for fixed parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">chunk</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>The sizes of the chunks in the <code class="highlight"><span></span><span class="n">chunk</span></code>
list.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="n">chunk_size</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>The first chunk is for trainable parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="n">TRAINING_PARAMS_IDX</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>The list of parameters, split into two lists: trainable parameters and fixed parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">param_refs</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">]]</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>CUDA stream used to fetch parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">fetch_stream</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>CUDA stream used to back up/accumulate gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">61</span> <span class="n">backup_stream</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>List of layers right before this layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="n">prev_layer</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s1">&#39;Zero3Layer&#39;</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>List of layers right after this layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">next_layer</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s1">&#39;Zero3Layer&#39;</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>The position of the current layer; used in debugging logs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span> <span class="n">layer_idx</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Whether the parameters have been fetched</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">is_fetched</span><span class="p">:</span> <span class="nb">bool</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Device of the layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Data type of the layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>The module to be wrapped</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">77</span> <span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Number of nodes/devices the data is sharded across</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="n">world_size</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">module</span></code>
is the module to be wrapped.</li>
<li><code class="highlight"><span></span><span class="n">rank</span></code>
is the rank of the current node.</li>
<li><code class="highlight"><span></span><span class="n">world_size</span></code>
is the number of nodes/devices the data is sharded across.</li>
<li><code class="highlight"><span></span><span class="n">device</span></code>
is the device of the layer.</li>
<li><code class="highlight"><span></span><span class="n">dtype</span></code>
is the data type of the layer.</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">world_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Initialize the properties</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
<span class="lineno">93</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="lineno">94</span> <span class="bp">self</span><span class="o">.</span><span class="n">module</span> <span class="o">=</span> <span class="n">module</span>
<span class="lineno">95</span> <span class="bp">self</span><span class="o">.</span><span class="n">prev_layer</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">96</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_layer</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">97</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_fetched</span> <span class="o">=</span> <span class="kc">False</span>
<span class="lineno">98</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">=</span> <span class="n">world_size</span>
<span class="lineno">99</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_idx</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
<span class="lineno">100</span> <span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span> <span class="o">=</span> <span class="kc">None</span>
<span class="lineno">101</span> <span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span> <span class="o">=</span> <span class="kc">None</span>
<span class="lineno">102</span>
<span class="lineno">103</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Collect all the parameters of the layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span> <span class="n">all_param_refs</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">parameters</span><span class="p">()]</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Store the shape of each parameter, because we need it later to reconstruct the parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">all_param_refs</span><span class="p">:</span>
<span class="lineno">109</span>     <span class="n">p</span><span class="o">.</span><span class="n">_orig_shape</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>All parameters should have the same data type</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">all_param_refs</span><span class="p">:</span>
<span class="lineno">113</span>     <span class="k">assert</span> <span class="n">p</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">dtype</span><span class="p">,</span> <span class="s2">&quot;All parameters should have same dtype&quot;</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Separate the parameters into trainable and fixed</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span> <span class="o">=</span> <span class="p">[[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">all_param_refs</span> <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">],</span>
<span class="lineno">117</span>                    <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">all_param_refs</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">p</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">]]</span>
<span class="lineno">118</span> <span class="k">del</span> <span class="n">all_param_refs</span></pre></div>
</div>
</div>
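<div class='section'>
<div class='docs'>
<p>As a standalone illustration of this split, the sketch below applies the same two list comprehensions to a toy layer. The <code>nn.Linear(4, 3)</code> layer and the frozen bias are assumptions made for the example only.</p>

```python
from torch import nn

# A toy layer; freezing the bias is purely for illustration
layer = nn.Linear(4, 3)
layer.bias.requires_grad_(False)

all_params = list(layer.parameters())
# Index 0 holds trainable parameters, index 1 holds fixed (frozen) ones
param_refs = [[p for p in all_params if p.requires_grad],
              [p for p in all_params if not p.requires_grad]]

trainable_count = sum(p.numel() for p in param_refs[0])
fixed_count = sum(p.numel() for p in param_refs[1])
print(trainable_count, fixed_count)
```

<p>Keeping the two groups apart lets gradients and optimizer state be kept only for the trainable group.</p>
</div>
</div>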
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>The <code class="highlight"><span></span><span class="n">rank</span> <span class="o">=</span> <span class="mi">0</span></code>
node calculates the size each device/node should store and distributes the parameters accordingly.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="k">if</span> <span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Merge and pad the trainable (<code class="highlight"><span></span><span class="n">merged_params</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></code>
) and fixed (<code class="highlight"><span></span><span class="n">merged_params</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></code>
) parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">merged_params</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_merge_and_pad_params</span><span class="p">(</span><span class="n">ps</span><span class="p">)</span> <span class="k">for</span> <span class="n">ps</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Calculate the chunk sizes of the trainable and fixed parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span> <span class="o">=</span> <span class="p">[(</span><span class="nb">len</span><span class="p">(</span><span class="n">p</span><span class="p">)</span> <span class="o">//</span> <span class="n">world_size</span> <span class="k">if</span> <span class="n">p</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">merged_params</span><span class="p">]</span></pre></div>
</div>
</div>
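<div class='section'>
<div class='docs'>
<p>The body of <code>_merge_and_pad_params</code> is not shown in this part of the page, so the following is only a sketch of what merging and padding could look like. The helper <code>merge_and_pad</code>, its zero padding, and its <code>None</code> return for an empty group are assumptions; the chunk-size expression is the one used above.</p>

```python
from typing import List, Optional

import torch


def merge_and_pad(params: List[torch.Tensor], world_size: int) -> Optional[torch.Tensor]:
    """Hypothetical helper: flatten tensors into one 1-D tensor and
    zero-pad it so its length is a multiple of `world_size`."""
    if not params:
        return None
    flat = torch.cat([p.reshape(-1) for p in params])
    # Number of padding elements needed to make the length divisible by `world_size`
    padding = (-flat.numel()) % world_size
    if padding:
        flat = torch.cat([flat, flat.new_zeros(padding)])
    return flat


world_size = 4
trainable = [torch.ones(3, 3), torch.ones(5)]  # 14 elements, padded to 16
fixed: List[torch.Tensor] = []                 # no fixed parameters in this example

merged = [merge_and_pad(ps, world_size) for ps in (trainable, fixed)]
# Same expression as in the layer: elements each rank stores per group
chunk_size = [(len(p) // world_size if p is not None else 0) for p in merged]
print(chunk_size)
```

<p>Padding makes every rank's shard the same length, which equal-sized collectives such as scatter require.</p>
</div>
</div>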
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Broadcast the sizes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span>     <span class="n">dist</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">),</span> <span class="n">src</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="lineno">129</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Create an empty tensor to receive the sizes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">131</span> <span class="n">chunk_size</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Receive the sizes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="n">dist</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="n">chunk_size</span><span class="p">,</span> <span class="n">src</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="lineno">134</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span> <span class="o">=</span> <span class="n">chunk_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Create the parameters that hold the trainable (<code class="highlight"><span></span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></code>
) and fixed (<code class="highlight"><span></span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></code>
) shards stored on the current device/node</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk</span> <span class="o">=</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="n">s</span><span class="p">,)),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="n">i</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">)</span>
<span class="lineno">139</span>               <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>An empty tensor to receive the trainable and fixed parameters combined</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="n">chunk</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">),))</span>
<span class="lineno">143</span>
<span class="lineno">144</span> <span class="k">if</span> <span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
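<p>Concatenate the trainable and fixed parameters, so that the slice for each rank holds its trainable shard followed by its fixed shard</p>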
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">all_params</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">p</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">world_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">merged_params</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">147</span> <span class="k">del</span> <span class="n">merged_params</span></pre></div>
</div>
</div>
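<p>The layout produced by the concatenation above can be hard to picture. The following is a minimal pure-Python sketch of it, using lists in place of tensors. The helper name <code>interleave_for_scatter</code> and the sample values are hypothetical and are not part of the implementation.</p>

```python
from typing import List


def interleave_for_scatter(merged: List[List[float]], world_size: int) -> List[List[float]]:
    """Return one flat chunk per rank: [trainable shard | fixed shard]."""
    chunks = []
    for rank in range(world_size):
        chunk: List[float] = []
        for flat in merged:
            # Each merged list is assumed to be padded to a multiple of world_size
            shard = len(flat) // world_size
            chunk.extend(flat[rank * shard:(rank + 1) * shard])
        chunks.append(chunk)
    return chunks


trainable = [1.0, 2.0, 3.0, 4.0]  # already padded: 2 values per rank
fixed = [10.0, 20.0]              # already padded: 1 value per rank
per_rank = interleave_for_scatter([trainable, fixed], world_size=2)
print(per_rank)
```

<p>Each inner list corresponds to one entry of the list passed to the scatter call, so every rank receives its trainable values first and its fixed values after them.</p>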
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Scatter them to all the nodes/devices</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">dist</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">chunk</span><span class="p">,</span> <span class="nb">list</span><span class="p">(</span><span class="n">all_params</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">))))</span>
<span class="lineno">151</span> <span class="k">del</span> <span class="n">all_params</span>
<span class="lineno">152</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Receive the parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">dist</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">chunk</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Copy the received data into the trainable and fixed chunks</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span> <span class="n">chunk</span> <span class="o">=</span> <span class="n">chunk</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">)</span>
<span class="lineno">158</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">chunk</span><span class="p">):</span>
<span class="lineno">159</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">data</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">c</span>
<span class="lineno">160</span> <span class="k">del</span> <span class="n">chunk</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Clean up the normal parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cleanup_params</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Add a backward hook. It gets called when the gradients with respect to the module are computed.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_ref</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_full_backward_hook</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook</span><span class="p">)</span> <span class="c1"># type: ignore</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<h4>Merge all the parameters and pad the result so that its size is divisible by <code class="highlight"><span></span><span class="n">world_size</span></code>.</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="k">def</span> <span class="nf">_merge_and_pad_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Total number of elements across all the parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">size</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">params</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>If it is not divisible by <code class="highlight"><span></span><span class="n">world_size</span></code>, compute the padding that is needed</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="k">if</span> <span class="n">size</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">177</span> <span class="n">padding_fixed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">-</span> <span class="p">(</span><span class="n">size</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Otherwise, no padding is needed</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">180</span> <span class="n">padding_fixed</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Create an empty padding tensor</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">182</span> <span class="n">padding</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="n">padding_fixed</span><span class="p">,))</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Concatenate all the flattened parameters and the padding</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">p</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">params</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="n">padding</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
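<p>As a rough illustration of the padding arithmetic above, here is a small pure-Python sketch that uses lists instead of tensors. The names <code>padding_needed</code> and <code>merge_and_pad</code> are hypothetical, and the pad value <code>0.0</code> is an assumption: the implementation pads with an uninitialized empty tensor.</p>

```python
from typing import List, Sequence


def padding_needed(size: int, world_size: int) -> int:
    """Number of extra elements needed to make size divisible by world_size."""
    remainder = size % world_size
    return 0 if remainder == 0 else world_size - remainder


def merge_and_pad(params: Sequence[Sequence[float]], world_size: int) -> List[float]:
    """Flatten all parameters into one list and pad it (0.0 is an illustrative pad value)."""
    flat = [value for p in params for value in p]
    return flat + [0.0] * padding_needed(len(flat), world_size)


merged = merge_and_pad([[1.0, 2.0, 3.0], [4.0, 5.0]], world_size=4)
print(len(merged))  # 5 elements padded up to 8, so each of 4 ranks gets 2
```

<p>After padding, the merged size divided by <code>world_size</code> gives the number of elements each rank stores.</p>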
<div class='section' id='section-43'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<h3>Get the trainable chunk/shard of the parameters.</h3>
<p>This is what we pass to the optimizer on the current node.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">186</span> <span class="k">def</span> <span class="nf">get_trainable_chunk</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Return an empty list if there are no trainable parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">])</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">194</span> <span class="k">return</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Return the trainable chunk as a list</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]]</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<h4>Create an empty tensor of the given shape.</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="k">def</span> <span class="nf">_empty</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<h4>Clean up the parameter data</h4>
<p>This releases all the memory used by the layer parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">206</span> <span class="k">def</span> <span class="nf">_cleanup_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Set the flag to indicate that the parameters are not fetched</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">214</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_fetched</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Iterate through all the parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="k">for</span> <span class="n">ps</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">:</span>
<span class="lineno">218</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">ps</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>Wait for the operations on the parameter to complete before any new operations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">220</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>Check that the parameter is not sharing storage with anything else</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">222</span> <span class="k">assert</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">storage_offset</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;The tensor is not the sole occupant of the storage.&quot;</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Resize the storage to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span></span></span></span></span>. This releases the memory used by the parameter.</p>
<p><strong>Setting <code class="highlight"><span></span><span class="n">p</span><span class="o">.</span><span class="n">data</span></code> does not release the memory, because the autograd graph keeps a reference to it.</strong></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">226</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">storage</span><span class="p">()</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># This is what actually clears the memory</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>Make sure the parameter has no gradient data</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">228</span> <span class="k">assert</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;Gradients should be None&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<h3>Fetch the parameters from all shards</h3>
<p>This fetches the parameter data from all the nodes and rebuilds the parameters on each node.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">230</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">231</span> <span class="k">def</span> <span class="nf">fetch_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p>Skip if already fetched</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">239</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_fetched</span><span class="p">:</span>
<span class="lineno">240</span> <span class="k">return</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Set the flag</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">243</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_fetched</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Skip if there is nothing to fetch or share.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">246</span> <span class="k">if</span> <span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">247</span> <span class="k">return</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Use <code class="highlight"><span></span><span class="n">fetch_stream</span></code> to fetch the parameters from all the shards</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">250</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>Create an empty tensor to receive the parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">252</span> <span class="n">buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">*</span> <span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">),))</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
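<p>Split the contiguous buffer into one piece per node. These splits are views of <code class="highlight"><span></span><span class="n">buffer</span></code>.</p>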
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">254</span> <span class="n">buffers</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">buffer</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">)))</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<p>Concatenate the trainable and fixed chunks</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">257</span> <span class="n">chunk</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>Gather the parameters from all the nodes/devices</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">260</span> <span class="n">dist</span><span class="o">.</span><span class="n">all_gather</span><span class="p">(</span><span class="n">buffers</span><span class="p">,</span> <span class="n">chunk</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<p>Split the gathered parameters into the trainable and fixed chunks</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">263</span> <span class="n">params</span> <span class="o">=</span> <span class="n">buffer</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">))</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<p>Wait for the gather operation to complete, then clear the references to the buffers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">265</span> <span class="n">buffer</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">266</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="n">buffers</span><span class="p">:</span>
<span class="lineno">267</span> <span class="n">b</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">268</span> <span class="n">buffer</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">269</span> <span class="k">del</span> <span class="n">buffer</span>
<span class="lineno">270</span> <span class="k">del</span> <span class="n">buffers</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<p>Reshape the trainable and fixed parameters into contiguous tensors</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">273</span> <span class="n">params</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">params</span><span class="p">]</span></pre></div>
</div>
</div>
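<p>The gather, split and reshape steps above undo the per-rank layout used when the parameters were scattered. The following pure-Python sketch mimics that regrouping with lists in place of tensors. The function name <code>split_gathered</code> and the sample values are hypothetical.</p>

```python
from typing import List, Sequence


def split_gathered(buffer: Sequence[float], chunk_size: Sequence[int]) -> List[List[float]]:
    """Regroup [rank0 chunk | rank1 chunk | ...] into one flat list per parameter group."""
    row = sum(chunk_size)
    if row == 0:
        # Nothing to fetch; the implementation returns early in this case
        return [[] for _ in chunk_size]
    # One row per rank, like buffer.view(-1, sum(chunk_size))
    rows = [buffer[i:i + row] for i in range(0, len(buffer), row)]
    groups = []
    start = 0
    for size in chunk_size:
        # Take the same column range from every row, then flatten row by row
        groups.append([v for r in rows for v in r[start:start + size]])
        start += size
    return groups


gathered = [1.0, 2.0, 10.0, 3.0, 4.0, 20.0]  # two ranks, chunk_size = [2, 1]
trainable, fixed = split_gathered(gathered, chunk_size=[2, 1])
print(trainable, fixed)
```

<p>The trainable values from all ranks end up in one contiguous list and the fixed values in another, which is the order in which the individual parameters are then read back.</p>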
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
<p>Collect the individual parameter tensors</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">276</span> <span class="k">for</span> <span class="n">cont</span><span class="p">,</span> <span class="n">ps</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
<p>Skip if there are no parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">278</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">ps</span><span class="p">:</span>
<span class="lineno">279</span> <span class="k">continue</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
<p>Offset into the contiguous tensor</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">282</span> <span class="n">offset</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-70'>
<div class='docs'>
<div class='section-link'>
<a href='#section-70'>#</a>
</div>
<p>Iterate through the model parameters and assign the values from the contiguous tensor</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">284</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">ps</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-71'>
<div class='docs'>
<div class='section-link'>
<a href='#section-71'>#</a>
</div>
<p>Original parameter shape</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">286</span> <span class="n">shape</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">_orig_shape</span> <span class="c1"># type: ignore[attr-defined]</span></pre></div>
</div>
</div>
<div class='section' id='section-72'>
<div class='docs'>
<div class='section-link'>
<a href='#section-72'>#</a>
</div>
<p>Restore the storage size of the parameter. It was set to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span></span></span></span></span> when we cleaned up the parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">288</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">storage</span><span class="p">()</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span><span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-73'>
<div class='docs'>
<div class='section-link'>
<a href='#section-73'>#</a>
</div>
<p>Assign the values from the contiguous tensor</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">290</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">cont</span><span class="p">[</span><span class="n">offset</span><span class="p">:</span> <span class="n">offset</span> <span class="o">+</span> <span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-74'>
<div class='docs'>
<div class='section-link'>
<a href='#section-74'>#</a>
</div>
<p>Wait for the operations to complete before other operations can be performed</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">292</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-75'>
<div class='docs'>
<div class='section-link'>
<a href='#section-75'>#</a>
</div>
<p>Update the offset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">294</span> <span class="n">offset</span> <span class="o">+=</span> <span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-76'>
<div class='docs'>
<div class='section-link'>
<a href='#section-76'>#</a>
</div>
<p>Wait for the operation to complete before other operations can be performed</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">297</span> <span class="n">cont</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-77'>
<div class='docs'>
<div class='section-link'>
<a href='#section-77'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">300</span> <span class="k">del</span> <span class="n">params</span></pre></div>
</div>
</div>
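<p>The offset bookkeeping used above to rebuild individual parameters can be sketched in pure Python as follows. This is an illustration only: <code>unpack</code> is a hypothetical helper, and lists stand in for the contiguous tensor and the parameter storages.</p>

```python
from math import prod
from typing import List, Sequence, Tuple


def unpack(cont: Sequence[float], shapes: Sequence[Tuple[int, ...]]) -> List[List[float]]:
    """Slice a flat sequence into consecutive pieces, one per original shape."""
    offset = 0
    pieces = []
    for shape in shapes:
        numel = prod(shape)  # number of elements in this parameter
        pieces.append(list(cont[offset:offset + numel]))
        offset += numel      # move past the values just consumed
    # Any values left after the last parameter are padding and are ignored
    return pieces


flat = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0]  # last two values are padding
weights, bias = unpack(flat, [(2, 2), (2,)])
print(weights, bias)
```

<p>Because the parameters are read back in the same order in which they were merged, no per-parameter index needs to be stored apart from the original shapes.</p>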
<div class='section' id='section-78'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-78'>#</a>
</div>
<h3>Forward pass</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">302</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-79'>
<div class='docs'>
<div class='section-link'>
<a href='#section-79'>#</a>
</div>
<p>Fetch all the parameters of the current node. The previous layer already calls this, so this call only makes sure that the parameters are fetched.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">309</span> <span class="bp">self</span><span class="o">.</span><span class="n">fetch_params</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-80'>
<div class='docs'>
<div class='section-link'>
<a href='#section-80'>#</a>
</div>
<p>Wait for parameter fetching to complete.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">312</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">wait_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-81'>
<div class='docs'>
<div class='section-link'>
<a href='#section-81'>#</a>
</div>
<p>Start fetching the parameters of the following layers, so that they are fetched while the current layer does its computation.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">316</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_layer</span><span class="p">:</span>
<span class="lineno">317</span> <span class="n">layer</span><span class="o">.</span><span class="n">fetch_params</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-82'>
<div class='docs'>
<div class='section-link'>
<a href='#section-82'>#</a>
</div>
<p>Add backward hooks to the parameters of the current layer if autograd is enabled.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">320</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">is_grad_enabled</span><span class="p">():</span>
<span class="lineno">321</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_backward_hooks</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-83'>
<div class='docs'>
<div class='section-link'>
<a href='#section-83'>#</a>
</div>
<p>Compute the outputs of the current layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">324</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">module</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-84'>
<div class='docs'>
<div class='section-link'>
<a href='#section-84'>#</a>
</div>
<p>Clean up the parameters of the layer.</p>
<p><em>Skip the cleanup if autograd is enabled and this is the last layer in the network, because the parameters would have to be fetched again straight away for the backward pass.</em></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">330</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">is_grad_enabled</span><span class="p">()</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_layer</span><span class="p">:</span>
<span class="lineno">331</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cleanup_params</span><span class="p">()</span>
<span class="lineno">332</span>
<span class="lineno">333</span> <span class="k">return</span> <span class="n">res</span></pre></div>
</div>
</div>
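<div class='section'>
<div class='docs'>
<p>The prefetch order of the forward pass can be illustrated with a small pure-Python sketch. It is not the implementation above: <code class="highlight"><span></span><span class="n">ToyLayer</span></code>, <code class="highlight"><span></span><span class="n">run_forward</span></code> and the log strings are hypothetical names, there are no CUDA streams, and it assumes that fetching already-fetched parameters is a no-op. It only shows that each layer triggers the fetch for its successor before computing, and that the last layer keeps its parameters when autograd is enabled.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
class ToyLayer:
    """Hypothetical stand-in for a sharded layer; it only tracks fetch state."""

    def __init__(self, name, log):
        self.name = name
        self.log = log
        self.next_layer = []
        self.is_fetched = False

    def fetch_params(self):
        # Assumption: fetching is idempotent, so a second call is a no-op.
        if self.is_fetched:
            return
        self.is_fetched = True
        self.log.append(f"fetch {self.name}")

    def cleanup_params(self):
        self.is_fetched = False
        self.log.append(f"free {self.name}")

    def forward(self, x, grad_enabled=False):
        # Make sure our own parameters are there.
        self.fetch_params()
        # Kick off the fetch for the following layers before computing.
        for layer in self.next_layer:
            layer.fetch_params()
        self.log.append(f"compute {self.name}")
        # Keep the parameters only for the last layer when autograd is enabled.
        if not grad_enabled or self.next_layer:
            self.cleanup_params()
        return x + 1


def run_forward(num_layers=3, grad_enabled=False):
    log = []
    layers = [ToyLayer(f"L{i}", log) for i in range(num_layers)]
    for current, following in zip(layers, layers[1:]):
        current.next_layer.append(following)
    x = 0
    for layer in layers:
        x = layer.forward(x, grad_enabled)
    return x, log


forward_result, forward_log = run_forward()
```
</pre></div>
</div>
</div>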
<div class='section' id='section-85'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-85'>#</a>
</div>
<h4>Add backward hooks to the parameters of the current layer</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">335</span> <span class="k">def</span> <span class="nf">_add_backward_hooks</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-86'>
<div class='docs'>
<div class='section-link'>
<a href='#section-86'>#</a>
</div>
<p>Number of backward hooks added</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">341</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_handles</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-87'>
<div class='docs'>
<div class='section-link'>
<a href='#section-87'>#</a>
</div>
<p>Loop through the trainable parameters of the current layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">344</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-88'>
<div class='docs'>
<div class='section-link'>
<a href='#section-88'>#</a>
</div>
<p>Make sure a hook has not already been added</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">346</span> <span class="k">assert</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="s2">&quot;_hook_handle&quot;</span><span class="p">),</span> <span class="s1">&#39;Parameter has already been hooked&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-89'>
<div class='docs'>
<div class='section-link'>
<a href='#section-89'>#</a>
</div>
<p>Use <code class="highlight"><span></span><span class="n">expand_as</span></code>
 to create an autograd step that we can intercept</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">348</span> <span class="n">p_tmp</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">expand_as</span><span class="p">(</span><span class="n">p</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-90'>
<div class='docs'>
<div class='section-link'>
<a href='#section-90'>#</a>
</div>
<p>Get a handle for attaching the backward hook. <a href="https://amsword.medium.com/understanding-pytorchs-autograd-with-grad-fn-and-next-functions-b2c4836daa00">This blog post discusses <code class="highlight"><span></span><span class="n">grad_acc</span></code>
</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">351</span> <span class="n">grad_acc</span> <span class="o">=</span> <span class="n">p_tmp</span><span class="o">.</span><span class="n">grad_fn</span><span class="o">.</span><span class="n">next_functions</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-91'>
<div class='docs'>
<div class='section-link'>
<a href='#section-91'>#</a>
</div>
<p>Add the backward hook</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">353</span> <span class="n">handle</span> <span class="o">=</span> <span class="n">grad_acc</span><span class="o">.</span><span class="n">register_hook</span><span class="p">(</span>
<span class="lineno">354</span> <span class="n">functools</span><span class="o">.</span><span class="n">partial</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_post_backward_hook</span><span class="p">,</span> <span class="n">p</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-92'>
<div class='docs'>
<div class='section-link'>
<a href='#section-92'>#</a>
</div>
<p>Keep a reference to the handle</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">356</span> <span class="n">p</span><span class="o">.</span><span class="n">_hook_handle</span> <span class="o">=</span> <span class="n">handle</span></pre></div>
</div>
</div>
<div class='section' id='section-93'>
<div class='docs'>
<div class='section-link'>
<a href='#section-93'>#</a>
</div>
<p>Increment the number of hooks added</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">358</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_handles</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-94'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-94'>#</a>
</div>
<h4>Handle a backward event</h4>
<p>This gets called by the parameter backward hooks and by the module backward hook.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">360</span> <span class="k">def</span> <span class="nf">_backward_event</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-95'>
<div class='docs'>
<div class='section-link'>
<a href='#section-95'>#</a>
</div>
<p>Decrement the hook counter</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">368</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_handles</span> <span class="o">-=</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-96'>
<div class='docs'>
<div class='section-link'>
<a href='#section-96'>#</a>
</div>
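<p>If all the hooks (including the module hook) have been called, we can back up the gradients and clean up the parameters. The counter starts at the number of parameter hooks, so it reaches -1 once the module hook has fired as well.</p>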
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">372</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_hook_handles</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
<span class="lineno">373</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backup_grads</span><span class="p">()</span>
<span class="lineno">374</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cleanup_params</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-97'>
<div class='docs'>
<div class='section-link'>
<a href='#section-97'>#</a>
</div>
<p>Start fetching the parameters of the previous layer, because autograd will process its gradients next.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">377</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">prev_layer</span><span class="p">:</span>
<span class="lineno">378</span> <span class="n">layer</span><span class="o">.</span><span class="n">fetch_params</span><span class="p">()</span></pre></div>
</div>
</div>
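<div class='section'>
<div class='docs'>
<p>The counting scheme can be sketched without autograd. In this minimal illustration, <code class="highlight"><span></span><span class="n">HookCounter</span></code> and its callback are hypothetical names; the counter starts at the number of parameter hooks and the callback runs only when the extra module-hook event brings it down to -1.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
class HookCounter:
    """Hypothetical helper that mirrors the decrement-and-check logic."""

    def __init__(self, num_param_hooks, on_done):
        # One pending event per parameter hook; the module hook is the extra one.
        self.remaining = num_param_hooks
        self.on_done = on_done

    def backward_event(self):
        self.remaining -= 1
        # -1 means every parameter hook and the module hook have fired.
        if self.remaining == -1:
            self.on_done()


fired = []
counter = HookCounter(3, lambda: fired.append("backup and cleanup"))

# Three parameter hooks fire first; nothing should happen yet.
for _ in range(3):
    counter.backward_event()
events_after_param_hooks = list(fired)

# The module hook is the final event and triggers the callback.
counter.backward_event()
```
</pre></div>
</div>
</div>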
<div class='section' id='section-98'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-98'>#</a>
</div>
<h4>Parameter backward hook</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">380</span> <span class="k">def</span> <span class="nf">_post_backward_hook</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-99'>
<div class='docs'>
<div class='section-link'>
<a href='#section-99'>#</a>
</div>
<p>Remove the handle from the parameter</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">385</span> <span class="n">p</span><span class="o">.</span><span class="n">_hook_handle</span><span class="o">.</span><span class="n">remove</span><span class="p">()</span> <span class="c1"># type: ignore[attr-defined]</span>
<span class="lineno">386</span> <span class="nb">delattr</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="s2">&quot;_hook_handle&quot;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-100'>
<div class='docs'>
<div class='section-link'>
<a href='#section-100'>#</a>
</div>
<p>Handle a backward event</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">389</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_event</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-101'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-101'>#</a>
</div>
<h4>Module backward hook</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">391</span> <span class="k">def</span> <span class="nf">_backward_hook</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-102'>
<div class='docs'>
<div class='section-link'>
<a href='#section-102'>#</a>
</div>
<p>Handle a backward event</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">396</span> <span class="bp">self</span><span class="o">.</span><span class="n">_backward_event</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-103'>
<div class='docs'>
<div class='section-link'>
<a href='#section-103'>#</a>
</div>
<p>The previous layer will start computing gradients next, so we need to make sure it has finished fetching its parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">399</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">wait_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-104'>
<div class='docs'>
<div class='section-link'>
<a href='#section-104'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">402</span> <span class="k">return</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-105'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-105'>#</a>
</div>
<h3>Back up the gradients of the current layer</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">404</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">405</span> <span class="k">def</span> <span class="nf">_backup_grads</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-106'>
<div class='docs'>
<div class='section-link'>
<a href='#section-106'>#</a>
</div>
<p>Skip if there are no trainable parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">410</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">411</span> <span class="k">return</span></pre></div>
</div>
</div>
<div class='section' id='section-107'>
<div class='docs'>
<div class='section-link'>
<a href='#section-107'>#</a>
</div>
<p>Use the backup stream to back up the gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">414</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-108'>
<div class='docs'>
<div class='section-link'>
<a href='#section-108'>#</a>
</div>
<p>Buffer to store the gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">416</span> <span class="n">buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">],))</span></pre></div>
</div>
</div>
<div class='section' id='section-109'>
<div class='docs'>
<div class='section-link'>
<a href='#section-109'>#</a>
</div>
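<p>Split the contiguous buffer into one chunk per node. These splits are views of <code class="highlight"><span></span><span class="n">buffer</span></code>.</p>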
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">418</span> <span class="n">buffers</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">buffer</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]))</span></pre></div>
</div>
</div>
<div class='section' id='section-110'>
<div class='docs'>
<div class='section-link'>
<a href='#section-110'>#</a>
</div>
<p>Offset into the contiguous buffer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">421</span> <span class="n">offset</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-111'>
<div class='docs'>
<div class='section-link'>
<a href='#section-111'>#</a>
</div>
<p>Iterate through the trainable parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">423</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_refs</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-112'>
<div class='docs'>
<div class='section-link'>
<a href='#section-112'>#</a>
</div>
<p>Collect the gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">425</span> <span class="n">shape</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">_orig_shape</span> <span class="c1"># type: ignore[attr-defined]</span>
<span class="lineno">426</span> <span class="n">buffer</span><span class="p">[</span><span class="n">offset</span><span class="p">:</span> <span class="n">offset</span> <span class="o">+</span> <span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()]</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-113'>
<div class='docs'>
<div class='section-link'>
<a href='#section-113'>#</a>
</div>
<p>Update the offset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">428</span> <span class="n">offset</span> <span class="o">+=</span> <span class="n">shape</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-114'>
<div class='docs'>
<div class='section-link'>
<a href='#section-114'>#</a>
</div>
<p>Clear the gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">430</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-115'>
<div class='docs'>
<div class='section-link'>
<a href='#section-115'>#</a>
</div>
<p>Empty tensor to accumulate the gradients of the current shard</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">433</span> <span class="n">grad</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_empty</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">],))</span></pre></div>
</div>
</div>
<div class='section' id='section-116'>
<div class='docs'>
<div class='section-link'>
<a href='#section-116'>#</a>
</div>
<p>Accumulate the gradients of each shard. This scatters the buffers across the nodes, and each node accumulates (reduces) the tensors it receives.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">436</span> <span class="n">dist</span><span class="o">.</span><span class="n">reduce_scatter</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="n">buffers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-117'>
<div class='docs'>
<div class='section-link'>
<a href='#section-117'>#</a>
</div>
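<p>Record the stream on the buffers so that their memory is not reused before the queued work completes, and then clear the references to the buffers.</p>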
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">439</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="n">buffers</span><span class="p">:</span>
<span class="lineno">440</span> <span class="n">b</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">441</span> <span class="n">buffer</span><span class="o">.</span><span class="n">record_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span><span class="p">)</span>
<span class="lineno">442</span> <span class="k">del</span> <span class="n">buffer</span>
<span class="lineno">443</span> <span class="k">del</span> <span class="n">buffers</span></pre></div>
</div>
</div>
<div class='section' id='section-118'>
<div class='docs'>
<div class='section-link'>
<a href='#section-118'>#</a>
</div>
<p>Set the chunk gradients. This is what the optimizer sees.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">446</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">TRAINING_PARAMS_IDX</span><span class="p">]</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">grad</span>
<span class="lineno">447</span> <span class="k">del</span> <span class="n">grad</span></pre></div>
</div>
</div>
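<div class='section'>
<div class='docs'>
<p>The flatten-then-reduce-scatter step can be mimicked with plain lists. The sketch below is a single-process simulation, not a call into <code class="highlight"><span></span><span class="n">torch.distributed</span></code>: <code class="highlight"><span></span><span class="n">flatten</span></code> and <code class="highlight"><span></span><span class="n">simulated_reduce_scatter</span></code> are hypothetical helpers, the reduction is assumed to be a sum, and the padding is assumed to be zero. Each simulated rank ends up with the summed gradients of its own chunk only.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def flatten(grads, total_size):
    """Copy per-parameter gradients into one flat, zero-padded buffer."""
    flat = [0.0] * total_size
    offset = 0
    for g in grads:
        flat[offset:offset + len(g)] = g
        offset += len(g)
    return flat


def simulated_reduce_scatter(flat_per_rank, chunk_size):
    """Return, for each rank, the element-wise sum of its chunk over all ranks."""
    world_size = len(flat_per_rank)
    result = []
    for rank in range(world_size):
        start = rank * chunk_size
        chunk = [
            sum(flat[start + i] for flat in flat_per_rank)
            for i in range(chunk_size)
        ]
        result.append(chunk)
    return result


# Two simulated ranks; the layer has parameters with 2 and 1 elements.
chunk_size = 2                      # elements owned by each rank
total_size = 2 * chunk_size         # world_size * chunk_size
rank0 = flatten([[1.0, 2.0], [3.0]], total_size)
rank1 = flatten([[10.0, 20.0], [30.0]], total_size)

shard_grads = simulated_reduce_scatter([rank0, rank1], chunk_size)
```
</pre></div>
</div>
</div>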
<div class='section' id='section-119'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-119'>#</a>
</div>
<h2>Sequential module for <code class="highlight"><span></span><span class="n">Zero3Layer</span></code>
 layers</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">450</span><span class="k">class</span> <span class="nc">Zero3Sequential</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-120'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-120'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">modules</span></code>
 is a list of <code class="highlight"><span></span><span class="n">Zero3Layer</span></code>
 layers</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">454</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">modules</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Zero3Layer</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-121'>
<div class='docs'>
<div class='section-link'>
<a href='#section-121'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">458</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-122'>
<div class='docs'>
<div class='section-link'>
<a href='#section-122'>#</a>
</div>
<p>CUDA stream for fetching parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">461</span> <span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-123'>
<div class='docs'>
<div class='section-link'>
<a href='#section-123'>#</a>
</div>
<p>CUDA stream for backing up (accumulating) gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">463</span> <span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-124'>
<div class='docs'>
<div class='section-link'>
<a href='#section-124'>#</a>
</div>
<p>Set the streams and the preceding and following layers for each <code class="highlight"><span></span><span class="n">Zero3Layer</span></code>
 layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">466</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">modules</span><span class="p">)):</span></pre></div>
</div>
</div>
<div class='section' id='section-125'>
<div class='docs'>
<div class='section-link'>
<a href='#section-125'>#</a>
</div>
<p>Set the layer index</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">468</span> <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">layer_idx</span> <span class="o">=</span> <span class="n">i</span></pre></div>
</div>
</div>
<div class='section' id='section-126'>
<div class='docs'>
<div class='section-link'>
<a href='#section-126'>#</a>
</div>
<p>Set the streams</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">470</span> <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">fetch_stream</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fetch_stream</span>
<span class="lineno">471</span> <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">backup_stream</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span></pre></div>
</div>
</div>
<div class='section' id='section-127'>
<div class='docs'>
<div class='section-link'>
<a href='#section-127'>#</a>
</div>
<p>Set the following layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">473</span> <span class="k">if</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">modules</span><span class="p">):</span>
<span class="lineno">474</span> <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">next_layer</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">modules</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-128'>
<div class='docs'>
<div class='section-link'>
<a href='#section-128'>#</a>
</div>
<p>Set the preceding layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">476</span> <span class="k">if</span> <span class="n">i</span> <span class="o">-</span> <span class="mi">1</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">477</span> <span class="n">modules</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">prev_layer</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">modules</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-129'>
<div class='docs'>
<div class='section-link'>
<a href='#section-129'>#</a>
</div>
<p>Store the list of modules</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">480</span> <span class="bp">self</span><span class="o">.</span><span class="n">module_list</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">(</span><span class="n">modules</span><span class="p">)</span></pre></div>
</div>
</div>
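<div class='section'>
<div class='docs'>
<p>The neighbour wiring done in the constructor can be sketched on its own. <code class="highlight"><span></span><span class="n">ToyNode</span></code> and <code class="highlight"><span></span><span class="n">link_layers</span></code> are hypothetical names used only for this illustration; the point is that the first layer has no predecessor, the last layer has no successor, and every other layer gets exactly one of each.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
class ToyNode:
    """Hypothetical layer stub holding only the linkage attributes."""

    def __init__(self, name):
        self.name = name
        self.layer_idx = None
        self.prev_layer = []
        self.next_layer = []


def link_layers(nodes):
    """Assign indices and connect each node to its direct neighbours."""
    last = len(nodes) - 1
    for i, node in enumerate(nodes):
        node.layer_idx = i
        # Every node except the last has a following layer.
        if i != last:
            node.next_layer.append(nodes[i + 1])
        # Every node except the first has a preceding layer.
        if i > 0:
            node.prev_layer.append(nodes[i - 1])
    return nodes


nodes = link_layers([ToyNode("a"), ToyNode("b"), ToyNode("c")])
middle_links = (
    [n.name for n in nodes[1].prev_layer],
    [n.name for n in nodes[1].next_layer],
)
```
</pre></div>
</div>
</div>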
<div class='section' id='section-130'>
<div class='docs'>
<div class='section-link'>
<a href='#section-130'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">482</span> <span class="k">def</span> <span class="nf">get_trainable_chunk</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-131'>
<div class='docs'>
<div class='section-link'>
<a href='#section-131'>#</a>
</div>
<p>Return the list of trainable chunks from each layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">484</span> <span class="k">return</span> <span class="nb">sum</span><span class="p">([</span><span class="n">m</span><span class="o">.</span><span class="n">get_trainable_chunk</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">module_list</span><span class="p">],</span> <span class="p">[])</span></pre></div>
</div>
</div>
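<div class='section'>
<div class='docs'>
<p>The <code class="highlight"><span></span><span class="nb">sum</span></code> call with an empty list as the start value concatenates the per-layer lists. A minimal sketch, assuming each layer returns a list that may be empty when it has no trainable parameters; <code class="highlight"><span></span><span class="n">flatten_chunks</span></code> and the string placeholders are hypothetical.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def flatten_chunks(per_layer_chunks):
    # sum(..., []) starts from an empty list and concatenates each sub-list.
    return sum(per_layer_chunks, [])


# Hypothetical per-layer results; strings stand in for parameter chunks.
chunks = flatten_chunks([["chunk_0"], [], ["chunk_2"]])
```
</pre></div>
</div>
</div>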
<div class='section' id='section-132'>
<div class='docs'>
<div class='section-link'>
<a href='#section-132'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">486</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-133'>
<div class='docs'>
<div class='section-link'>
<a href='#section-133'>#</a>
</div>
<p>Make sure the gradient backup is complete</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">488</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">wait_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">backup_stream</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-134'>
<div class='docs'>
<div class='section-link'>
<a href='#section-134'>#</a>
</div>
<p>Forward pass</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">491</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">module_list</span><span class="p">:</span>
<span class="lineno">492</span> <span class="n">x</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-135'>
<div class='docs'>
<div class='section-link'>
<a href='#section-135'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">495</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>