|
|
|
@ -96,8 +96,7 @@ In the update, some features of $c$ are cleared with a forget gate $f$,
|
|
|
|
|
and some features $i$ are added through a gate $g$.</p>
|
|
|
|
|
<p>The new short term memory is the $\tanh$ of the long-term memory
|
|
|
|
|
multiplied by the output gate $o$.</p>
|
|
|
|
|
<p>Note that the cell doesn’t look at long term memory $c$ when doing the update
|
|
|
|
|
for the update. It only modifies it.
|
|
|
|
|
<p>Note that the cell doesn’t look at long term memory $c$ when doing the update. It only modifies it.
|
|
|
|
|
Also $c$ never goes through a linear transformation.
|
|
|
|
|
This is what solves vanishing and exploding gradients.</p>
|
|
|
|
|
<p>Here’s the update rule.</p>
|
|
|
|
@ -131,8 +130,8 @@ o_t &= lin_x^o(x_t) + lin_h^o(h_{t-1})
|
|
|
|
|
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">59</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">layer_norm</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
|
|
|
|
|
<span class="lineno">60</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">58</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">layer_norm</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
|
|
|
|
|
<span class="lineno">59</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-3'>
|
|
|
|
@ -155,7 +154,7 @@ One of them doesn’t need a bias since we add the transformations.</p>
|
|
|
|
|
<p>This combines $lin_x^i$, $lin_x^f$, $lin_x^g$, and $lin_x^o$ transformations.</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">66</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_lin</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">hidden_size</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">65</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_lin</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">hidden_size</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-5'>
|
|
|
|
@ -166,7 +165,7 @@ One of them doesn’t need a bias since we add the transformations.</p>
|
|
|
|
|
<p>This combines $lin_h^i$, $lin_h^f$, $lin_h^g$, and $lin_h^o$ transformations.</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">68</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_lin</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">input_size</span><span class="p">,</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">67</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_lin</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">input_size</span><span class="p">,</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-6'>
|
|
|
|
@ -180,12 +179,12 @@ $i$, $f$, $g$ and $o$ embeddings are normalized and $c_t$ is normalized in
|
|
|
|
|
$h_t = o_t \odot \tanh(\mathop{LN}(c_t))$</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">75</span> <span class="k">if</span> <span class="n">layer_norm</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">76</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">)])</span>
|
|
|
|
|
<span class="lineno">77</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_norm_c</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">78</span> <span class="k">else</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">79</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">)])</span>
|
|
|
|
|
<span class="lineno">80</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_norm_c</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">74</span> <span class="k">if</span> <span class="n">layer_norm</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">75</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">)])</span>
|
|
|
|
|
<span class="lineno">76</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_norm_c</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">77</span> <span class="k">else</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">78</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">)])</span>
|
|
|
|
|
<span class="lineno">79</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_norm_c</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-7'>
|
|
|
|
@ -196,7 +195,7 @@ $h_t = o_t \odot \tanh(\mathop{LN}(c_t))$</p>
|
|
|
|
|
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">82</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">c</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">81</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">c</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-8'>
|
|
|
|
@ -208,7 +207,7 @@ $h_t = o_t \odot \tanh(\mathop{LN}(c_t))$</p>
|
|
|
|
|
using the same linear layers.</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">85</span> <span class="n">ifgo</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_lin</span><span class="p">(</span><span class="n">h</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_lin</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">84</span> <span class="n">ifgo</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_lin</span><span class="p">(</span><span class="n">h</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_lin</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-9'>
|
|
|
|
@ -219,7 +218,7 @@ using the same linear layers.</p>
|
|
|
|
|
<p>Each layer produces an output of 4 times the <code>hidden_size</code> and we split them</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">87</span> <span class="n">ifgo</span> <span class="o">=</span> <span class="n">ifgo</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">86</span> <span class="n">ifgo</span> <span class="o">=</span> <span class="n">ifgo</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-10'>
|
|
|
|
@ -230,7 +229,7 @@ using the same linear layers.</p>
|
|
|
|
|
<p>Apply layer normalization (not in original paper, but gives better results)</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">90</span> <span class="n">ifgo</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">ifgo</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">)]</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">89</span> <span class="n">ifgo</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">ifgo</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">)]</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-11'>
|
|
|
|
@ -243,7 +242,7 @@ using the same linear layers.</p>
|
|
|
|
|
</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">93</span> <span class="n">i</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">g</span><span class="p">,</span> <span class="n">o</span> <span class="o">=</span> <span class="n">ifgo</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">92</span> <span class="n">i</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">g</span><span class="p">,</span> <span class="n">o</span> <span class="o">=</span> <span class="n">ifgo</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-12'>
|
|
|
|
@ -256,7 +255,7 @@ using the same linear layers.</p>
|
|
|
|
|
</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">96</span> <span class="n">c_next</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">f</span><span class="p">)</span> <span class="o">*</span> <span class="n">c</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">g</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">95</span> <span class="n">c_next</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">f</span><span class="p">)</span> <span class="o">*</span> <span class="n">c</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">g</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-13'>
|
|
|
|
@ -269,9 +268,9 @@ using the same linear layers.</p>
|
|
|
|
|
Optionally, apply layer norm to $c_t$</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">100</span> <span class="n">h_next</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">o</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_norm_c</span><span class="p">(</span><span class="n">c_next</span><span class="p">))</span>
|
|
|
|
|
<span class="lineno">101</span>
|
|
|
|
|
<span class="lineno">102</span> <span class="k">return</span> <span class="n">h_next</span><span class="p">,</span> <span class="n">c_next</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">99</span> <span class="n">h_next</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">o</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_norm_c</span><span class="p">(</span><span class="n">c_next</span><span class="p">))</span>
|
|
|
|
|
<span class="lineno">100</span>
|
|
|
|
|
<span class="lineno">101</span> <span class="k">return</span> <span class="n">h_next</span><span class="p">,</span> <span class="n">c_next</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-14'>
|
|
|
|
@ -282,7 +281,7 @@ Optionally, apply layer norm to $c_t$</p>
|
|
|
|
|
<h2>Multilayer LSTM</h2>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">105</span><span class="k">class</span> <span class="nc">LSTM</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">104</span><span class="k">class</span> <span class="nc">LSTM</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-15'>
|
|
|
|
@ -293,7 +292,7 @@ Optionally, apply layer norm to $c_t$</p>
|
|
|
|
|
<p>Create a network of <code>n_layers</code> of LSTM.</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">110</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">109</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-16'>
|
|
|
|
@ -304,9 +303,9 @@ Optionally, apply layer norm to $c_t$</p>
|
|
|
|
|
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">115</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
|
|
|
|
<span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span> <span class="o">=</span> <span class="n">n_layers</span>
|
|
|
|
|
<span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">hidden_size</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">114</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
|
|
|
|
<span class="lineno">115</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span> <span class="o">=</span> <span class="n">n_layers</span>
|
|
|
|
|
<span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">hidden_size</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-17'>
|
|
|
|
@ -318,8 +317,8 @@ Optionally, apply layer norm to $c_t$</p>
|
|
|
|
|
Rest of the layers get the input from the layer below</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">120</span> <span class="bp">self</span><span class="o">.</span><span class="n">cells</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">input_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)]</span> <span class="o">+</span>
|
|
|
|
|
<span class="lineno">121</span> <span class="p">[</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_layers</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)])</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">119</span> <span class="bp">self</span><span class="o">.</span><span class="n">cells</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">input_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)]</span> <span class="o">+</span>
|
|
|
|
|
<span class="lineno">120</span> <span class="p">[</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_layers</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)])</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-18'>
|
|
|
|
@ -331,7 +330,7 @@ Rest of the layers get the input from the layer below</p>
|
|
|
|
|
<code>state</code> is a tuple of $h$ and $c$, each with a shape of <code>[batch_size, hidden_size]</code>.</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">123</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">122</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-19'>
|
|
|
|
@ -342,7 +341,7 @@ Rest of the layers get the input from the layer below</p>
|
|
|
|
|
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">n_steps</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">127</span> <span class="n">n_steps</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-20'>
|
|
|
|
@ -353,11 +352,11 @@ Rest of the layers get the input from the layer below</p>
|
|
|
|
|
<p>Initialize the state if <code>None</code></p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">131</span> <span class="k">if</span> <span class="n">state</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">132</span> <span class="n">h</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span><span class="p">)]</span>
|
|
|
|
|
<span class="lineno">133</span> <span class="n">c</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span><span class="p">)]</span>
|
|
|
|
|
<span class="lineno">134</span> <span class="k">else</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">135</span> <span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span> <span class="o">=</span> <span class="n">state</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">if</span> <span class="n">state</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">131</span> <span class="n">h</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span><span class="p">)]</span>
|
|
|
|
|
<span class="lineno">132</span> <span class="n">c</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span><span class="p">)]</span>
|
|
|
|
|
<span class="lineno">133</span> <span class="k">else</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">134</span> <span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span> <span class="o">=</span> <span class="n">state</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-21'>
|
|
|
|
@ -369,7 +368,7 @@ Rest of the layers get the input from the layer below</p>
|
|
|
|
|
📝 You can just work with the tensor itself but this is easier to debug</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">h</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">unbind</span><span class="p">(</span><span class="n">h</span><span class="p">)),</span> <span class="nb">list</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">unbind</span><span class="p">(</span><span class="n">c</span><span class="p">))</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">137</span> <span class="n">h</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">unbind</span><span class="p">(</span><span class="n">h</span><span class="p">)),</span> <span class="nb">list</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">unbind</span><span class="p">(</span><span class="n">c</span><span class="p">))</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-22'>
|
|
|
|
@ -380,8 +379,8 @@ Rest of the layers get the input from the layer below</p>
|
|
|
|
|
<p>Array to collect the outputs of the final layer at each time step.</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">out</span> <span class="o">=</span> <span class="p">[]</span>
|
|
|
|
|
<span class="lineno">142</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_steps</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">out</span> <span class="o">=</span> <span class="p">[]</span>
|
|
|
|
|
<span class="lineno">141</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_steps</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-23'>
|
|
|
|
@ -392,7 +391,7 @@ Rest of the layers get the input from the layer below</p>
|
|
|
|
|
<p>Input to the first layer is the input itself</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">inp</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">t</span><span class="p">]</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">inp</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">t</span><span class="p">]</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-24'>
|
|
|
|
@ -403,7 +402,7 @@ Rest of the layers get the input from the layer below</p>
|
|
|
|
|
<p>Loop through the layers</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">146</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">145</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-25'>
|
|
|
|
@ -414,7 +413,7 @@ Rest of the layers get the input from the layer below</p>
|
|
|
|
|
<p>Get the state of the layer</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">148</span> <span class="n">h</span><span class="p">[</span><span class="n">layer</span><span class="p">],</span> <span class="n">c</span><span class="p">[</span><span class="n">layer</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cells</span><span class="p">[</span><span class="n">layer</span><span class="p">](</span><span class="n">inp</span><span class="p">,</span> <span class="n">h</span><span class="p">[</span><span class="n">layer</span><span class="p">],</span> <span class="n">c</span><span class="p">[</span><span class="n">layer</span><span class="p">])</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">h</span><span class="p">[</span><span class="n">layer</span><span class="p">],</span> <span class="n">c</span><span class="p">[</span><span class="n">layer</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cells</span><span class="p">[</span><span class="n">layer</span><span class="p">](</span><span class="n">inp</span><span class="p">,</span> <span class="n">h</span><span class="p">[</span><span class="n">layer</span><span class="p">],</span> <span class="n">c</span><span class="p">[</span><span class="n">layer</span><span class="p">])</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-26'>
|
|
|
|
@ -425,7 +424,7 @@ Rest of the layers get the input from the layer below</p>
|
|
|
|
|
<p>Input to the next layer is the state of this layer</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">inp</span> <span class="o">=</span> <span class="n">h</span><span class="p">[</span><span class="n">layer</span><span class="p">]</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">149</span> <span class="n">inp</span> <span class="o">=</span> <span class="n">h</span><span class="p">[</span><span class="n">layer</span><span class="p">]</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-27'>
|
|
|
|
@ -436,7 +435,7 @@ Rest of the layers get the input from the layer below</p>
|
|
|
|
|
<p>Collect the output $h$ of the final layer</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">152</span> <span class="n">out</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">h</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">151</span> <span class="n">out</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">h</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-28'>
|
|
|
|
@ -447,11 +446,11 @@ Rest of the layers get the input from the layer below</p>
|
|
|
|
|
<p>Stack the outputs and states</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">155</span> <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">156</span> <span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">157</span> <span class="n">c</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">c</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">158</span>
|
|
|
|
|
<span class="lineno">159</span> <span class="k">return</span> <span class="n">out</span><span class="p">,</span> <span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">155</span> <span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">156</span> <span class="n">c</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">c</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">157</span>
|
|
|
|
|
<span class="lineno">158</span> <span class="k">return</span> <span class="n">out</span><span class="p">,</span> <span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|