This commit is contained in:
Varuna Jayasiri
2022-09-24 14:37:45 +05:30
parent 160f25a938
commit de36f9b6be
2 changed files with 144 additions and 131 deletions

View File

@ -633,10 +633,10 @@
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<b>MarkdownException</b> + Italic: not ending with *
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">189</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span></pre></div>
<div class="highlight"><pre><span class="lineno">188</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;flash&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
@ -644,6 +644,17 @@
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<b>MarkdownException</b> + Italic: not ending with *
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Stack <code class="highlight"><span></span><span class="n">q</span></code>
, <code class="highlight"><span></span><span class="n">k</span></code>
, <code class="highlight"><span></span><span class="n">v</span></code>
@ -652,19 +663,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Split the heads </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">195</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
@ -672,6 +671,18 @@
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Split the heads </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Flash attention works for head sizes <code class="highlight"><span></span><span class="mi">32</span></code>
, <code class="highlight"><span></span><span class="mi">64</span></code>
and <code class="highlight"><span></span><span class="mi">128</span></code>
@ -679,27 +690,14 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">32</span><span class="p">:</span>
<span class="lineno">200</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">32</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span>
<span class="lineno">201</span> <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">64</span><span class="p">:</span>
<span class="lineno">202</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">64</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span>
<span class="lineno">203</span> <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">128</span><span class="p">:</span>
<span class="lineno">204</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">128</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span>
<span class="lineno">205</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">206</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Head size $</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="si">}</span><span class="s1"> too large for Flash Attention&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Pad the heads </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">209</span> <span class="k">if</span> <span class="n">pad</span><span class="p">:</span>
<span class="lineno">210</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">qkv</span><span class="p">,</span> <span class="n">qkv</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">pad</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">201</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">32</span><span class="p">:</span>
<span class="lineno">202</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">32</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span>
<span class="lineno">203</span> <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">64</span><span class="p">:</span>
<span class="lineno">204</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">64</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span>
<span class="lineno">205</span> <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">128</span><span class="p">:</span>
<span class="lineno">206</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">128</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span>
<span class="lineno">207</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">208</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Head size $</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="si">}</span><span class="s1"> too large for Flash Attention&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
@ -707,10 +705,12 @@
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>$$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$</p>
<p>Pad the heads </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">215</span> <span class="n">out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span><span class="p">(</span><span class="n">qkv</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">211</span> <span class="k">if</span> <span class="n">pad</span><span class="p">:</span>
<span class="lineno">212</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">qkv</span><span class="p">,</span> <span class="n">qkv</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">pad</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
@ -718,11 +718,10 @@
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Truncate the extra head size </p>
<p>$$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">217</span> <span class="n">out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span><span class="p">(</span><span class="n">qkv</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
@ -730,12 +729,11 @@
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>Reshape to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">]</span></code>
</p>
<p>Truncate the extra head size </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
@ -743,18 +741,31 @@
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>Reshape to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">221</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Map to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
with a linear layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">222</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">224</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='section' id='section-50'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-49'>#</a>
<a href='#section-50'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">q</span></code>
are the query vectors before splitting heads, of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">d_attn</span><span class="p">]</span></code>
@ -768,22 +779,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">224</span> <span class="k">def</span> <span class="nf">normal_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Split them to heads of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_head</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">232</span> <span class="n">q</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">233</span> <span class="n">k</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">k</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">234</span> <span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">226</span> <span class="k">def</span> <span class="nf">normal_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
@ -791,10 +787,14 @@
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>$$\frac{Q K^\top}{\sqrt{d_{key}}}$$</p>
<p>Split them to heads of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_head</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">237</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bihd,bjhd-&gt;bhij&#39;</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
<div class="highlight"><pre><span class="lineno">234</span> <span class="n">q</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">235</span> <span class="n">k</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">k</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">236</span> <span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
@ -802,15 +802,10 @@
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>$$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$</p>
<p>$$\frac{Q K^\top}{\sqrt{d_{key}}}$$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">241</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_inplace</span><span class="p">:</span>
<span class="lineno">242</span> <span class="n">half</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span>
<span class="lineno">243</span> <span class="n">attn</span><span class="p">[</span><span class="n">half</span><span class="p">:]</span> <span class="o">=</span> <span class="n">attn</span><span class="p">[</span><span class="n">half</span><span class="p">:]</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">244</span> <span class="n">attn</span><span class="p">[:</span><span class="n">half</span><span class="p">]</span> <span class="o">=</span> <span class="n">attn</span><span class="p">[:</span><span class="n">half</span><span class="p">]</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">245</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">246</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">239</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bihd,bjhd-&gt;bhij&#39;</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
@ -818,10 +813,15 @@
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>$$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$</p>
<p>$$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">250</span> <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bhij,bjhd-&gt;bihd&#39;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">243</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_inplace</span><span class="p">:</span>
<span class="lineno">244</span> <span class="n">half</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span>
<span class="lineno">245</span> <span class="n">attn</span><span class="p">[</span><span class="n">half</span><span class="p">:]</span> <span class="o">=</span> <span class="n">attn</span><span class="p">[</span><span class="n">half</span><span class="p">:]</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">246</span> <span class="n">attn</span><span class="p">[:</span><span class="n">half</span><span class="p">]</span> <span class="o">=</span> <span class="n">attn</span><span class="p">[:</span><span class="n">half</span><span class="p">]</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">247</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">248</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
@ -829,12 +829,10 @@
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>Reshape to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">]</span></code>
</p>
<p>$$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">252</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">out</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">252</span> <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bhij,bjhd-&gt;bihd&#39;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
@ -842,24 +840,25 @@
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>Reshape to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">254</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">out</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p>Map to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
with a linear layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">254</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<h3>Feed-Forward Network</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">257</span><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">256</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
@ -867,6 +866,18 @@
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<h3>Feed-Forward Network</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">259</span><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the input embedding size </li>
<li><code class="highlight"><span></span><span class="n">d_mult</span></code>
@ -874,23 +885,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">262</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_mult</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">267</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">268</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="lineno">269</span> <span class="n">GeGLU</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span> <span class="o">*</span> <span class="n">d_mult</span><span class="p">),</span>
<span class="lineno">270</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.</span><span class="p">),</span>
<span class="lineno">271</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span> <span class="o">*</span> <span class="n">d_mult</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="lineno">272</span> <span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">264</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_mult</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
@ -901,31 +896,35 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">274</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">275</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">269</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">270</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="lineno">271</span> <span class="n">GeGLU</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span> <span class="o">*</span> <span class="n">d_mult</span><span class="p">),</span>
<span class="lineno">272</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.</span><span class="p">),</span>
<span class="lineno">273</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span> <span class="o">*</span> <span class="n">d_mult</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="lineno">274</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>$$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">278</span><span class="k">class</span> <span class="nc">GeGLU</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
<a href='#section-60'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">285</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_in</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_out</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">286</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">276</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">277</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p>$$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">280</span><span class="k">class</span> <span class="nc">GeGLU</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
@ -933,10 +932,11 @@
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<b>KeyError</b> + 'xW + b'
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">288</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">287</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_in</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_out</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">288</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
@ -944,10 +944,10 @@
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<b>KeyError</b> + 'xW + b'
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">290</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">290</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
@ -955,10 +955,10 @@
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<b>KeyError</b> + 'xW + b'
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">292</span> <span class="n">x</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">292</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
@ -966,10 +966,21 @@
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<b>KeyError</b> + 'xW + b'
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">294</span> <span class="n">x</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<p>$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">294</span> <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">F</span><span class="o">.</span><span class="n">gelu</span><span class="p">(</span><span class="n">gate</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">296</span> <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">F</span><span class="o">.</span><span class="n">gelu</span><span class="p">(</span><span class="n">gate</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>

View File

@ -185,6 +185,8 @@ class CrossAttention(nn.Module):
        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
"""
print('flash')
# Get batch size and number of elements along sequence axis (width * height)
batch_size, seq_len, _ = q.shape