This commit is contained in:
Varuna Jayasiri
2024-06-18 11:09:02 +05:30
parent 999f2036a5
commit cf565bcc1d
9 changed files with 793 additions and 980 deletions

View File

@ -794,11 +794,11 @@
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Set models for saving </p>
<p>Start the experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">242</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_model_savers</span><span class="p">({</span><span class="s1">&#39;info_sets&#39;</span><span class="p">:</span> <span class="n">InfoSetSaver</span><span class="p">(</span><span class="n">conf</span><span class="o">.</span><span class="n">cfr</span><span class="o">.</span><span class="n">info_sets</span><span class="p">)})</span></pre></div>
<div class="highlight"><pre><span class="lineno">242</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
@ -806,11 +806,11 @@
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Start the experiment </p>
<p>Start iterating </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">244</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">244</span> <span class="n">conf</span><span class="o">.</span><span class="n">cfr</span><span class="o">.</span><span class="n">iterate</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
@ -818,24 +818,12 @@
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Start iterating </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">246</span> <span class="n">conf</span><span class="o">.</span><span class="n">cfr</span><span class="o">.</span><span class="n">iterate</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">250</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">251</span> <span class="n">main</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">248</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">249</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>

View File

@ -164,39 +164,8 @@
<li><a href="sampling/nucleus.html">Nucleus Sampling</a></li></ul>
<h4><a href="scaling/index.html">Scalable Training/Inference</a></h4>
<ul><li><a href="scaling/zero3/index.html">Zero3 memory optimizations</a></li></ul>
<h2>Highlighted Research Paper PDFs</h2>
<ul><li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf">Autoregressive Search Engines: Generating Substrings as Document Identifiers</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf">Training Compute-Optimal Large Language Models</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1910.02054.pdf">ZeRO: Memory Optimizations Toward Training Trillion Parameter Models</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.02311.pdf">PaLM: Scaling Language Modeling with Pathways</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/dall-e-2.pdf">Hierarchical Text-Conditional Image Generation with CLIP Latents</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.14465.pdf">STaR: Self-Taught Reasoner Bootstrapping Reasoning With Reasoning</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2112.04426.pdf">Improving language models by retrieving from trillions of tokens</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2003.08934.pdf">NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1706.03762.pdf">Attention Is All You Need</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2006.11239.pdf">Denoising Diffusion Probabilistic Models</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2109.08668.pdf">Primer: Searching for Efficient Transformers for Language Modeling</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1803.02999.pdf">On First-Order Meta-Learning Algorithms</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2103.00020.pdf">Learning Transferable Visual Models From Natural Language Supervision</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2109.02869.pdf">The Sensory Neuron as a Transformer: Permutation-Invariant Neural Networks for Reinforcement Learning</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1805.09801.pdf">Meta-Gradient Reinforcement Learning</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/google_maps_eta.pdf">ETA Prediction with Graph Neural Networks in Google Maps</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/ponder_net.pdf">PonderNet: Learning to Ponder</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/muzero.pdf">Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/gans_n_roses.pdf">GANs N Roses: Stable, Controllable, Diverse Image to Image Translation (works for videos too!)</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/vit.pdf">An Image is Worth 16X16 Word: Transformers for Image Recognition at Scale</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/resnet.pdf">Deep Residual Learning for Image Recognition</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/distillation.pdf">Distilling the Knowledge in a Neural Network</a></li></ul>
<h3>Installation</h3>
<pre class="highlight lang-bash"><code><span></span>pip<span class="w"> </span>install<span class="w"> </span>labml-nn</code></pre>
<h3>Citing LabML</h3>
<p>If you use this for academic research, please cite it using the following BibTeX entry.</p>
<pre class="highlight lang-bibtex"><code><span></span><span class="nc">@misc</span><span class="p">{</span><span class="nl">labml</span><span class="p">,</span>
<span class="w"> </span><span class="na">author</span><span class="w"> </span><span class="p">=</span><span class="w"> </span><span class="s">{Varuna Jayasiri, Nipun Wijerathne}</span><span class="p">,</span>
<span class="w"> </span><span class="na">title</span><span class="w"> </span><span class="p">=</span><span class="w"> </span><span class="s">{labml.ai Annotated Paper Implementations}</span><span class="p">,</span>
<span class="w"> </span><span class="na">year</span><span class="w"> </span><span class="p">=</span><span class="w"> </span><span class="s">{2020}</span><span class="p">,</span>
<span class="w"> </span><span class="na">url</span><span class="w"> </span><span class="p">=</span><span class="w"> </span><span class="s">{}</span><span class="p">,</span>
<span class="p">}</span></code></pre>
</div>
<div class='code'>

View File

@ -175,8 +175,7 @@
</div>
<p> <a id="TransformerLayer"></a></p>
<h2>Transformer Layer</h2>
<p>This can act as an encoder layer or a decoder layer.</p>
<p>🗒 Some implementations, including the paper seem to have differences in where the layer-normalization is done. Here we do a layer normalization before attention and feed-forward networks, and add the original residual vectors. Alternative is to do a layer normalization after adding the residuals. But we found this to be less stable when training. We found a detailed discussion about this in the paper <a href="https://arxiv.org/abs/2002.04745">On Layer Normalization in the Transformer Architecture</a>.</p>
<p>This can act as an encoder layer or a decoder layer. We use pre-norm.</p>
</div>
<div class='code'>
@ -201,12 +200,12 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">79</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">80</span> <span class="n">self_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span><span class="p">,</span>
<span class="lineno">81</span> <span class="n">src_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="lineno">82</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span>
<span class="lineno">83</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">69</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">70</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">71</span> <span class="n">self_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span><span class="p">,</span>
<span class="lineno">72</span> <span class="n">src_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="lineno">73</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span>
<span class="lineno">74</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -217,16 +216,16 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">92</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">93</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span> <span class="o">=</span> <span class="n">self_attn</span>
<span class="lineno">94</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span> <span class="o">=</span> <span class="n">src_attn</span>
<span class="lineno">95</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
<span class="lineno">96</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span>
<span class="lineno">97</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">98</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">99</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_src_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">100</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
<div class="highlight"><pre><span class="lineno">82</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">83</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">84</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span> <span class="o">=</span> <span class="n">self_attn</span>
<span class="lineno">85</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span> <span class="o">=</span> <span class="n">src_attn</span>
<span class="lineno">86</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
<span class="lineno">87</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span>
<span class="lineno">88</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">89</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">90</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_src_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">91</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -238,7 +237,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">102</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_save_ff_input</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
<div class="highlight"><pre><span class="lineno">93</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_save_ff_input</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -249,11 +248,11 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">104</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">105</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">106</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">107</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="lineno">108</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">95</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">96</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">97</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">98</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="lineno">99</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -265,7 +264,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">101</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -277,7 +276,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">103</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
@ -289,7 +288,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">105</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
@ -301,7 +300,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="k">if</span> <span class="n">src</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">110</span> <span class="k">if</span> <span class="n">src</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
@ -313,7 +312,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_src_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">112</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_src_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
@ -325,7 +324,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">123</span> <span class="n">attn_src</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">src</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">src</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">src_mask</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">114</span> <span class="n">attn_src</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">src</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">src</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">src_mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
@ -337,7 +336,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">125</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn_src</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">116</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn_src</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
@ -349,7 +348,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">119</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
@ -361,8 +360,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_save_ff_input</span><span class="p">:</span>
<span class="lineno">131</span> <span class="bp">self</span><span class="o">.</span><span class="n">ff_input</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">121</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_save_ff_input</span><span class="p">:</span>
<span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">ff_input</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
@ -374,7 +373,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="n">ff</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">ff</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
@ -386,9 +385,9 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">135</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span>
<span class="lineno">136</span>
<span class="lineno">137</span> <span class="k">return</span> <span class="n">x</span></pre></div>
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span>
<span class="lineno">127</span>
<span class="lineno">128</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
@ -401,7 +400,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span><span class="k">class</span> <span class="nc">Encoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">131</span><span class="k">class</span> <span class="nc">Encoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
@ -412,8 +411,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">148</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">138</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">139</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
@ -425,7 +424,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">141</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
@ -437,7 +436,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">152</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
<div class="highlight"><pre><span class="lineno">143</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
@ -448,7 +447,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">145</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
@ -460,8 +459,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="lineno">157</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">147</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="lineno">148</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
@ -473,7 +472,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">159</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">150</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
@ -486,7 +485,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span><span class="k">class</span> <span class="nc">Decoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">153</span><span class="k">class</span> <span class="nc">Decoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
@ -497,8 +496,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">169</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">170</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">160</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">161</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
@ -510,7 +509,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">172</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">163</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
@ -522,7 +521,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">174</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
<div class="highlight"><pre><span class="lineno">165</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
@ -533,7 +532,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">memory</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">167</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">memory</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
@ -545,8 +544,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">178</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="lineno">179</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">tgt_mask</span><span class="p">,</span> <span class="n">src</span><span class="o">=</span><span class="n">memory</span><span class="p">,</span> <span class="n">src_mask</span><span class="o">=</span><span class="n">src_mask</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">169</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="lineno">170</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">tgt_mask</span><span class="p">,</span> <span class="n">src</span><span class="o">=</span><span class="n">memory</span><span class="p">,</span> <span class="n">src_mask</span><span class="o">=</span><span class="n">src_mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
@ -558,7 +557,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">172</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
@ -573,7 +572,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span><span class="k">class</span> <span class="nc">Generator</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">175</span><span class="k">class</span> <span class="nc">Generator</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
@ -584,9 +583,9 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">194</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">195</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">196</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">185</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">186</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
@ -597,8 +596,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">198</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="lineno">199</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">189</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="lineno">190</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
@ -611,7 +610,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">202</span><span class="k">class</span> <span class="nc">EncoderDecoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">193</span><span class="k">class</span> <span class="nc">EncoderDecoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
@ -622,13 +621,13 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">209</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">decoder</span><span class="p">:</span> <span class="n">Decoder</span><span class="p">,</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">tgt_embed</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">generator</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="lineno">210</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">211</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span>
<span class="lineno">212</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">decoder</span>
<span class="lineno">213</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span> <span class="o">=</span> <span class="n">src_embed</span>
<span class="lineno">214</span> <span class="bp">self</span><span class="o">.</span><span class="n">tgt_embed</span> <span class="o">=</span> <span class="n">tgt_embed</span>
<span class="lineno">215</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">generator</span></pre></div>
<div class="highlight"><pre><span class="lineno">200</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">decoder</span><span class="p">:</span> <span class="n">Decoder</span><span class="p">,</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">tgt_embed</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">generator</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="lineno">201</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">202</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span>
<span class="lineno">203</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">decoder</span>
<span class="lineno">204</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span> <span class="o">=</span> <span class="n">src_embed</span>
<span class="lineno">205</span> <span class="bp">self</span><span class="o">.</span><span class="n">tgt_embed</span> <span class="o">=</span> <span class="n">tgt_embed</span>
<span class="lineno">206</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">generator</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
@ -640,9 +639,9 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">219</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
<span class="lineno">220</span> <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="lineno">221</span> <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">xavier_uniform_</span><span class="p">(</span><span class="n">p</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">210</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
<span class="lineno">211</span> <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="lineno">212</span> <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">xavier_uniform_</span><span class="p">(</span><span class="n">p</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
@ -653,7 +652,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">223</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">214</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
@ -665,7 +664,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">225</span> <span class="n">enc</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">216</span> <span class="n">enc</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
@ -677,7 +676,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">227</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">enc</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">218</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">enc</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
@ -688,8 +687,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">229</span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">230</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span><span class="p">(</span><span class="n">src</span><span class="p">),</span> <span class="n">src_mask</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">220</span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">221</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span><span class="p">(</span><span class="n">src</span><span class="p">),</span> <span class="n">src_mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
@ -700,8 +699,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">232</span> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">memory</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">233</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">(</span><span class="n">tgt</span><span class="p">),</span> <span class="n">memory</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">223</span> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">memory</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">224</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">(</span><span class="n">tgt</span><span class="p">),</span> <span class="n">memory</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>

View File

@ -143,47 +143,9 @@ Solving games with incomplete information such as poker with CFR.
#### ✨ [Scalable Training/Inference](scaling/index.html)
* [Zero3 memory optimizations](scaling/zero3/index.html)
## Highlighted Research Paper PDFs
* [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf)
* [Training Compute-Optimal Large Language Models](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf)
* [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1910.02054.pdf)
* [PaLM: Scaling Language Modeling with Pathways](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.02311.pdf)
* [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/dall-e-2.pdf)
* [STaR: Self-Taught Reasoner Bootstrapping Reasoning With Reasoning](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.14465.pdf)
* [Improving language models by retrieving from trillions of tokens](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2112.04426.pdf)
* [NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2003.08934.pdf)
* [Attention Is All You Need](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1706.03762.pdf)
* [Denoising Diffusion Probabilistic Models](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2006.11239.pdf)
* [Primer: Searching for Efficient Transformers for Language Modeling](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2109.08668.pdf)
* [On First-Order Meta-Learning Algorithms](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1803.02999.pdf)
* [Learning Transferable Visual Models From Natural Language Supervision](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2103.00020.pdf)
* [The Sensory Neuron as a Transformer: Permutation-Invariant Neural Networks for Reinforcement Learning](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2109.02869.pdf)
* [Meta-Gradient Reinforcement Learning](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1805.09801.pdf)
* [ETA Prediction with Graph Neural Networks in Google Maps](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/google_maps_eta.pdf)
* [PonderNet: Learning to Ponder](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/ponder_net.pdf)
* [Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/muzero.pdf)
* [GANs N Roses: Stable, Controllable, Diverse Image to Image Translation (works for videos too!)](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/gans_n_roses.pdf)
* [An Image is Worth 16X16 Word: Transformers for Image Recognition at Scale](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/vit.pdf)
* [Deep Residual Learning for Image Recognition](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/resnet.pdf)
* [Distilling the Knowledge in a Neural Network](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/distillation.pdf)
### Installation
```bash
pip install labml-nn
```
### Citing LabML
If you use this for academic research, please cite it using the following BibTeX entry.
```bibtex
@misc{labml,
author = {Varuna Jayasiri, Nipun Wijerathne},
title = {labml.ai Annotated Paper Implementations},
year = {2020},
url = {},
}
```
"""

View File

@ -238,8 +238,6 @@ def main():
conf = Configs()
# Load configuration
experiment.configs(conf)
# Set models for saving
experiment.add_model_savers({'info_sets': InfoSetSaver(conf.cfr.info_sets)})
# Start the experiment
with experiment.start():
# Start iterating

View File

@ -168,49 +168,6 @@
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EvI7MtgJ61w5"
},
"source": [
"Set PyTorch models for loading and saving"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"id": "GDlt7dp-5ALt",
"outputId": "d72bab0a-f33d-4f3d-aa4c-3fb9d4443a41"
},
"source": [
"experiment.add_model_savers({'info_sets': InfoSetSaver(conf.cfr.info_sets)})"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<pre style=\"overflow-x: scroll;\">Prepare cfr...\n",
" Prepare create_new_history<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t4.72ms</span>\n",
"Prepare cfr<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t9.13ms</span>\n",
"</pre>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
@ -459,4 +416,4 @@
"outputs": []
}
]
}
}

View File

@ -63,16 +63,7 @@ class TransformerLayer(nn.Module):
## Transformer Layer
This can act as an encoder layer or a decoder layer.
🗒 Some implementations, including the paper seem to have differences
in where the layer-normalization is done.
Here we do a layer normalization before attention and feed-forward networks,
and add the original residual vectors.
Alternative is to do a layer normalization after adding the residuals.
But we found this to be less stable when training.
We found a detailed discussion about this in the paper
[On Layer Normalization in the Transformer Architecture](https://arxiv.org/abs/2002.04745).
This can act as an encoder layer or a decoder layer. We use pre-norm.
"""
def __init__(self, *,

File diff suppressed because it is too large Load Diff

View File

@ -140,59 +140,8 @@ Solving games with incomplete information such as poker with CFR.
#### ✨ [Scalable Training/Inference](https://nn.labml.ai/scaling/index.html)
* [Zero3 memory optimizations](https://nn.labml.ai/scaling/zero3/index.html)
## Highlighted Research Paper PDFs
* [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2205.14135.pdf)
* [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf)
* [Training Compute-Optimal Large Language Models](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf)
* [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1910.02054.pdf)
* [PaLM: Scaling Language Modeling with Pathways](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.02311.pdf)
* [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/dall-e-2.pdf)
* [STaR: Self-Taught Reasoner Bootstrapping Reasoning With Reasoning](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.14465.pdf)
* [Improving language models by retrieving from trillions of tokens](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2112.04426.pdf)
* [NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2003.08934.pdf)
* [Attention Is All You Need](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1706.03762.pdf)
* [Denoising Diffusion Probabilistic Models](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2006.11239.pdf)
* [Primer: Searching for Efficient Transformers for Language Modeling](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2109.08668.pdf)
* [On First-Order Meta-Learning Algorithms](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1803.02999.pdf)
* [Learning Transferable Visual Models From Natural Language Supervision](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2103.00020.pdf)
* [The Sensory Neuron as a Transformer: Permutation-Invariant Neural Networks for Reinforcement Learning](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2109.02869.pdf)
* [Meta-Gradient Reinforcement Learning](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1805.09801.pdf)
* [ETA Prediction with Graph Neural Networks in Google Maps](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/google_maps_eta.pdf)
* [PonderNet: Learning to Ponder](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/ponder_net.pdf)
* [Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/muzero.pdf)
* [GANs N Roses: Stable, Controllable, Diverse Image to Image Translation (works for videos too!)](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/gans_n_roses.pdf)
* [An Image is Worth 16X16 Word: Transformers for Image Recognition at Scale](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/vit.pdf)
* [Deep Residual Learning for Image Recognition](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/resnet.pdf)
* [Distilling the Knowledge in a Neural Network](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/distillation.pdf)
### Installation
```bash
pip install labml-nn
```
### Citing
If you use this for academic research, please cite it using the following BibTeX entry.
```bibtex
@misc{labml,
author = {Varuna Jayasiri, Nipun Wijerathne},
title = {labml.ai Annotated Paper Implementations},
year = {2020},
url = {https://nn.labml.ai/},
}
```
### Other Projects
#### [🚀 Trending Research Papers](https://papers.labml.ai/)
This shows the most popular research papers on social media. It also aggregates links to useful resources like paper explanations videos and discussions.
#### [🧪 labml.ai/labml](https://github.com/labmlai/labml)
This is a library that let's you monitor deep learning model training and hardware usage from your mobile phone. It also comes with a bunch of other tools to help write deep learning code efficiently.