This commit is contained in:
Varuna Jayasiri
2022-09-24 14:39:10 +05:30
parent de36f9b6be
commit eb92824e58
2 changed files with 59 additions and 55 deletions

View File

@ -602,10 +602,12 @@
<span class="lineno">173</span> <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_k</span><span class="p">(</span><span class="n">cond</span><span class="p">)</span>
<span class="lineno">174</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_v</span><span class="p">(</span><span class="n">cond</span><span class="p">)</span>
<span class="lineno">175</span>
<span class="lineno">176</span> <span class="k">if</span> <span class="n">CrossAttention</span><span class="o">.</span><span class="n">use_flash_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">cond</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">128</span><span class="p">:</span>
<span class="lineno">177</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
<span class="lineno">178</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">179</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">normal_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
<span class="lineno">176</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;use flash&#39;</span><span class="p">,</span> <span class="n">CrossAttention</span><span class="o">.</span><span class="n">use_flash_attention</span><span class="p">)</span>
<span class="lineno">177</span>
<span class="lineno">178</span> <span class="k">if</span> <span class="n">CrossAttention</span><span class="o">.</span><span class="n">use_flash_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">cond</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">128</span><span class="p">:</span>
<span class="lineno">179</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
<span class="lineno">180</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">181</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">normal_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
@ -625,7 +627,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="k">def</span> <span class="nf">flash_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">183</span> <span class="k">def</span> <span class="nf">flash_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
@ -636,7 +638,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">188</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;flash&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">190</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;flash&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
@ -647,7 +649,7 @@
<b>MarkdownException</b> + Italic: not ending with *
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span></pre></div>
<div class="highlight"><pre><span class="lineno">193</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
@ -663,7 +665,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
@ -675,7 +677,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
@ -690,14 +692,14 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">32</span><span class="p">:</span>
<span class="lineno">202</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">32</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span>
<span class="lineno">203</span> <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">64</span><span class="p">:</span>
<span class="lineno">204</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">64</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span>
<span class="lineno">205</span> <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">128</span><span class="p">:</span>
<span class="lineno">206</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">128</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span>
<span class="lineno">207</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">208</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Head size $</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="si">}</span><span class="s1"> too large for Flash Attention&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">32</span><span class="p">:</span>
<span class="lineno">204</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">32</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span>
<span class="lineno">205</span> <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">64</span><span class="p">:</span>
<span class="lineno">206</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">64</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span>
<span class="lineno">207</span> <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">128</span><span class="p">:</span>
<span class="lineno">208</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">128</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span>
<span class="lineno">209</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">210</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Head size $</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="si">}</span><span class="s1"> too large for Flash Attention&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
@ -709,8 +711,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">211</span> <span class="k">if</span> <span class="n">pad</span><span class="p">:</span>
<span class="lineno">212</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">qkv</span><span class="p">,</span> <span class="n">qkv</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">pad</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">213</span> <span class="k">if</span> <span class="n">pad</span><span class="p">:</span>
<span class="lineno">214</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">qkv</span><span class="p">,</span> <span class="n">qkv</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">pad</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
@ -721,7 +723,7 @@
<b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)V'
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="n">out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span><span class="p">(</span><span class="n">qkv</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span><span class="p">(</span><span class="n">qkv</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
@ -733,7 +735,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">221</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
@ -746,7 +748,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">221</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">223</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
@ -759,7 +761,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">224</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">226</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
@ -779,7 +781,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">226</span> <span class="k">def</span> <span class="nf">normal_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">228</span> <span class="k">def</span> <span class="nf">normal_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
@ -792,9 +794,9 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">234</span> <span class="n">q</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">235</span> <span class="n">k</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">k</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">236</span> <span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">236</span> <span class="n">q</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">237</span> <span class="n">k</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">k</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">238</span> <span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
@ -805,7 +807,7 @@
<b>KeyError</b> + '\\frac{Q K^\\top}{\\sqrt{d_{key}}}'
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">239</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bihd,bjhd-&gt;bhij&#39;</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
<div class="highlight"><pre><span class="lineno">241</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bihd,bjhd-&gt;bhij&#39;</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
@ -816,12 +818,12 @@
<b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)'
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">243</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_inplace</span><span class="p">:</span>
<span class="lineno">244</span> <span class="n">half</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span>
<span class="lineno">245</span> <span class="n">attn</span><span class="p">[</span><span class="n">half</span><span class="p">:]</span> <span class="o">=</span> <span class="n">attn</span><span class="p">[</span><span class="n">half</span><span class="p">:]</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">246</span> <span class="n">attn</span><span class="p">[:</span><span class="n">half</span><span class="p">]</span> <span class="o">=</span> <span class="n">attn</span><span class="p">[:</span><span class="n">half</span><span class="p">]</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">247</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">248</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">245</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_inplace</span><span class="p">:</span>
<span class="lineno">246</span> <span class="n">half</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span>
<span class="lineno">247</span> <span class="n">attn</span><span class="p">[</span><span class="n">half</span><span class="p">:]</span> <span class="o">=</span> <span class="n">attn</span><span class="p">[</span><span class="n">half</span><span class="p">:]</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">248</span> <span class="n">attn</span><span class="p">[:</span><span class="n">half</span><span class="p">]</span> <span class="o">=</span> <span class="n">attn</span><span class="p">[:</span><span class="n">half</span><span class="p">]</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">249</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">250</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
@ -832,7 +834,7 @@
<b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)V'
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">252</span> <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bhij,bjhd-&gt;bihd&#39;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">254</span> <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bhij,bjhd-&gt;bihd&#39;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
@ -845,7 +847,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">254</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">out</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">256</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">out</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
@ -858,7 +860,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">256</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">258</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
@ -870,7 +872,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">259</span><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">261</span><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
@ -885,7 +887,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">264</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_mult</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">266</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_mult</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
@ -896,12 +898,12 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">269</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">270</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="lineno">271</span> <span class="n">GeGLU</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span> <span class="o">*</span> <span class="n">d_mult</span><span class="p">),</span>
<span class="lineno">272</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.</span><span class="p">),</span>
<span class="lineno">273</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span> <span class="o">*</span> <span class="n">d_mult</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="lineno">274</span> <span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">271</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">272</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="lineno">273</span> <span class="n">GeGLU</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span> <span class="o">*</span> <span class="n">d_mult</span><span class="p">),</span>
<span class="lineno">274</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.</span><span class="p">),</span>
<span class="lineno">275</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span> <span class="o">*</span> <span class="n">d_mult</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="lineno">276</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
@ -912,8 +914,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">276</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">277</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">278</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">279</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
@ -924,7 +926,7 @@
<b>KeyError</b> + '\\text{GeGLU}(x) = (xW + b) * \\text{GELU}(xV + c)'
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">280</span><span class="k">class</span> <span class="nc">GeGLU</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">282</span><span class="k">class</span> <span class="nc">GeGLU</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
@ -935,8 +937,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">287</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_in</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_out</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">288</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">289</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_in</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_out</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">290</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
@ -947,7 +949,7 @@
<b>KeyError</b> + 'xW + b'
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">290</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">292</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
@ -958,7 +960,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">292</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">294</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
@ -969,7 +971,7 @@
<b>KeyError</b> + 'xW + b'
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">294</span> <span class="n">x</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">296</span> <span class="n">x</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
@ -980,7 +982,7 @@
<b>KeyError</b> + '\\text{GeGLU}(x) = (xW + b) * \\text{GELU}(xV + c)'
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">296</span> <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">F</span><span class="o">.</span><span class="n">gelu</span><span class="p">(</span><span class="n">gate</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">298</span> <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">F</span><span class="o">.</span><span class="n">gelu</span><span class="p">(</span><span class="n">gate</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>

View File

@ -173,6 +173,8 @@ class CrossAttention(nn.Module):
k = self.to_k(cond)
v = self.to_v(cond)
print('use flash', CrossAttention.use_flash_attention)
if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
return self.flash_attention(q, k, v)
else:
@ -186,7 +188,7 @@ class CrossAttention(nn.Module):
"""
print('flash')
# Get batch size and number of elements along sequence axis (width * height)
batch_size, seq_len, _ = q.shape