stochastic deptch

This commit is contained in:
Varuna Jayasiri
2021-06-07 15:08:55 +05:30
parent ed681b48b9
commit 194adc6d68
2 changed files with 48 additions and 44 deletions

View File

@ -68,17 +68,19 @@
<a href='#section-0'>#</a>
</div>
<h1><a href="index.html">Pay Attention to MLPs (gMLP)</a> Experiment</h1>
<p>This is an annotated PyTorch experiment to train a <a href="index.html">gMLP model</a>.</p>
<p>This is an annotated PyTorch experiment to train a <a href="index.html">gMLP model</a>.
The paper also applies a Stochastic Depth regularization where some layers are removed randomly during training.
We have not implemented that here.</p>
<p>This is based on
<a href="../basic/autoregressive_experiment.html">training loop and configurations for a simple transformer auto-regressive NLP task</a>.</p>
<p><a href="https://app.labml.ai/run/01bd941ac74c11eb890c1d9196651a4a"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">16</span><span></span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.basic.autoregressive_experiment</span> <span class="kn">import</span> <span class="n">Configs</span> <span class="k">as</span> <span class="n">BasicAutoRegressionConfigs</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.gmlp</span> <span class="kn">import</span> <span class="n">GMLPBlock</span></pre></div>
<div class="highlight"><pre><span class="lineno">18</span><span></span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.basic.autoregressive_experiment</span> <span class="kn">import</span> <span class="n">Configs</span> <span class="k">as</span> <span class="n">BasicAutoRegressionConfigs</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.gmlp</span> <span class="kn">import</span> <span class="n">GMLPBlock</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -91,7 +93,7 @@
<a href="../basic/autoregressive_transformer.html">training loop and configurations for a simple transformer auto-regressive NLP task</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">23</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">BasicAutoRegressionConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">25</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">BasicAutoRegressionConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -102,7 +104,7 @@
<p>Transformer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span> <span class="n">transformer</span><span class="p">:</span> <span class="n">TransformerConfigs</span> <span class="o">=</span> <span class="s1">&#39;gMLP&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">34</span> <span class="n">transformer</span><span class="p">:</span> <span class="n">TransformerConfigs</span> <span class="o">=</span> <span class="s1">&#39;gMLP&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -113,7 +115,7 @@
<p>gMLP Block</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">34</span> <span class="n">gmlp</span><span class="p">:</span> <span class="n">GMLPBlock</span></pre></div>
<div class="highlight"><pre><span class="lineno">36</span> <span class="n">gmlp</span><span class="p">:</span> <span class="n">GMLPBlock</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -124,7 +126,7 @@
<p><code>d_ffn</code> for gMLP projection layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">36</span> <span class="n">d_ffn</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span></pre></div>
<div class="highlight"><pre><span class="lineno">38</span> <span class="n">d_ffn</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -135,8 +137,8 @@
<h3>Create a gMLP block</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">39</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">gmlp</span><span class="p">,</span> <span class="s1">&#39;gMLP&#39;</span><span class="p">)</span>
<span class="lineno">40</span><span class="k">def</span> <span class="nf">_gmlp_configs</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">41</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">gmlp</span><span class="p">,</span> <span class="s1">&#39;gMLP&#39;</span><span class="p">)</span>
<span class="lineno">42</span><span class="k">def</span> <span class="nf">_gmlp_configs</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -147,7 +149,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">return</span> <span class="n">GMLPBlock</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_ffn</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">seq_len</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">46</span> <span class="k">return</span> <span class="n">GMLPBlock</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_ffn</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">seq_len</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -158,8 +160,8 @@
<h3>Transformer configurations</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">transformer</span><span class="p">,</span> <span class="s1">&#39;gMLP&#39;</span><span class="p">)</span>
<span class="lineno">48</span><span class="k">def</span> <span class="nf">_transformer_configs</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">49</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">transformer</span><span class="p">,</span> <span class="s1">&#39;gMLP&#39;</span><span class="p">)</span>
<span class="lineno">50</span><span class="k">def</span> <span class="nf">_transformer_configs</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -171,7 +173,7 @@
<a href="../configs.html#TransformerConfigs">configurable transformer implementation</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">TransformerConfigs</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">57</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">TransformerConfigs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -182,8 +184,8 @@
<p>Set the vocabulary sizes for embeddings and generating logits</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="n">conf</span><span class="o">.</span><span class="n">n_src_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span>
<span class="lineno">58</span> <span class="n">conf</span><span class="o">.</span><span class="n">n_tgt_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span></pre></div>
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">conf</span><span class="o">.</span><span class="n">n_src_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span>
<span class="lineno">60</span> <span class="n">conf</span><span class="o">.</span><span class="n">n_tgt_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -194,7 +196,7 @@
<p>Set model size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">conf</span><span class="o">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span></pre></div>
<div class="highlight"><pre><span class="lineno">62</span> <span class="n">conf</span><span class="o">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -205,9 +207,9 @@
<p>Replace the encoder layer with a gMLP layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="n">conf</span><span class="o">.</span><span class="n">encoder_layer</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">gmlp</span>
<span class="lineno">63</span>
<span class="lineno">64</span> <span class="k">return</span> <span class="n">conf</span></pre></div>
<div class="highlight"><pre><span class="lineno">64</span> <span class="n">conf</span><span class="o">.</span><span class="n">encoder_layer</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">gmlp</span>
<span class="lineno">65</span>
<span class="lineno">66</span> <span class="k">return</span> <span class="n">conf</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -218,7 +220,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">69</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -229,7 +231,7 @@
<p>Create experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;gMLP&quot;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">71</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;gMLP&quot;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
@ -240,7 +242,7 @@
<p>Create configs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
@ -251,7 +253,7 @@
<p>Override configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span></pre></div>
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
@ -262,7 +264,7 @@
<p>Use character level tokenizer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="s1">&#39;tokenizer&#39;</span><span class="p">:</span> <span class="s1">&#39;character&#39;</span><span class="p">,</span></pre></div>
<div class="highlight"><pre><span class="lineno">77</span> <span class="s1">&#39;tokenizer&#39;</span><span class="p">:</span> <span class="s1">&#39;character&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
@ -273,7 +275,7 @@
<p>Prompt separator is blank</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">77</span> <span class="s1">&#39;prompt_separator&#39;</span><span class="p">:</span> <span class="s1">&#39;&#39;</span><span class="p">,</span></pre></div>
<div class="highlight"><pre><span class="lineno">79</span> <span class="s1">&#39;prompt_separator&#39;</span><span class="p">:</span> <span class="s1">&#39;&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
@ -284,7 +286,7 @@
<p>Starting prompt for sampling</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="s1">&#39;prompt&#39;</span><span class="p">:</span> <span class="s1">&#39;It is &#39;</span><span class="p">,</span></pre></div>
<div class="highlight"><pre><span class="lineno">81</span> <span class="s1">&#39;prompt&#39;</span><span class="p">:</span> <span class="s1">&#39;It is &#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
@ -295,7 +297,7 @@
<p>Use Tiny Shakespeare dataset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span> <span class="s1">&#39;text&#39;</span><span class="p">:</span> <span class="s1">&#39;tiny_shakespeare&#39;</span><span class="p">,</span></pre></div>
<div class="highlight"><pre><span class="lineno">83</span> <span class="s1">&#39;text&#39;</span><span class="p">:</span> <span class="s1">&#39;tiny_shakespeare&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
@ -306,7 +308,7 @@
<p>Use a context size of $256$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">84</span> <span class="s1">&#39;seq_len&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span></pre></div>
<div class="highlight"><pre><span class="lineno">86</span> <span class="s1">&#39;seq_len&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
@ -317,7 +319,7 @@
<p>Train for $128$ epochs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span></pre></div>
<div class="highlight"><pre><span class="lineno">88</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
@ -328,7 +330,7 @@
<p>Batch size $32$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="s1">&#39;batch_size&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span></pre></div>
<div class="highlight"><pre><span class="lineno">90</span> <span class="s1">&#39;batch_size&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
@ -340,7 +342,7 @@
per epoch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="s1">&#39;inner_iterations&#39;</span><span class="p">:</span> <span class="mi">10</span><span class="p">,</span></pre></div>
<div class="highlight"><pre><span class="lineno">93</span> <span class="s1">&#39;inner_iterations&#39;</span><span class="p">:</span> <span class="mi">10</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
@ -351,8 +353,8 @@ per epoch</p>
<p>Model size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">94</span> <span class="s1">&#39;d_model&#39;</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span>
<span class="lineno">95</span> <span class="s1">&#39;d_ffn&#39;</span><span class="p">:</span> <span class="mi">2048</span><span class="p">,</span></pre></div>
<div class="highlight"><pre><span class="lineno">96</span> <span class="s1">&#39;d_model&#39;</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span>
<span class="lineno">97</span> <span class="s1">&#39;d_ffn&#39;</span><span class="p">:</span> <span class="mi">2048</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
@ -363,9 +365,9 @@ per epoch</p>
<p>Use <a href="../../optimizers/noam.html">Noam optimizer</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">98</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Noam&#39;</span><span class="p">,</span>
<span class="lineno">99</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">1.</span><span class="p">,</span>
<span class="lineno">100</span> <span class="p">})</span></pre></div>
<div class="highlight"><pre><span class="lineno">100</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Noam&#39;</span><span class="p">,</span>
<span class="lineno">101</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">1.</span><span class="p">,</span>
<span class="lineno">102</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
@ -376,7 +378,7 @@ per epoch</p>
<p>Set models for saving and loading</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
<div class="highlight"><pre><span class="lineno">105</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
@ -387,7 +389,7 @@ per epoch</p>
<p>Start the experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
@ -398,7 +400,7 @@ per epoch</p>
<p>Run training</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">110</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
@ -409,8 +411,8 @@ per epoch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">113</span> <span class="n">main</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">114</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">115</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
</div>

View File

@ -7,6 +7,8 @@ summary: This experiment trains a gMLP based model on Tiny Shakespeare dataset.
# [Pay Attention to MLPs (gMLP)](index.html) Experiment
This is an annotated PyTorch experiment to train a [gMLP model](index.html).
The paper also applies a Stochastic Depth regularization where some layers are removed randomly during training.
We have not implemented that here.
This is based on
[training loop and configurations for a simple transformer auto-regressive NLP task](../basic/autoregressive_experiment.html).