This commit is contained in:
Varuna Jayasiri
2022-09-24 14:41:55 +05:30
parent f4b2d46925
commit 1a49b753e4
2 changed files with 81 additions and 72 deletions

View File

@ -579,14 +579,10 @@
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>If <code class="highlight"><span></span><span class="n">cond</span></code>
is <code class="highlight"><span></span><span class="kc">None</span></code>
we perform self attention </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="k">if</span> <span class="n">cond</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">169</span> <span class="n">cond</span> <span class="o">=</span> <span class="n">x</span></pre></div>
<div class="highlight"><pre><span class="lineno">167</span> <span class="n">has_cond</span> <span class="o">=</span> <span class="n">cond</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
@ -594,26 +590,39 @@
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>If <code class="highlight"><span></span><span class="n">cond</span></code>
is <code class="highlight"><span></span><span class="kc">None</span></code>
we perform self attention </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">170</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">has_cond</span><span class="p">:</span>
<span class="lineno">171</span> <span class="n">cond</span> <span class="o">=</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Get query, key and value vectors </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">172</span> <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_q</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">173</span> <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_k</span><span class="p">(</span><span class="n">cond</span><span class="p">)</span>
<span class="lineno">174</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_v</span><span class="p">(</span><span class="n">cond</span><span class="p">)</span>
<span class="lineno">175</span>
<span class="lineno">176</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;use flash&#39;</span><span class="p">,</span> <span class="n">CrossAttention</span><span class="o">.</span><span class="n">use_flash_attention</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span><span class="p">)</span>
<div class="highlight"><pre><span class="lineno">174</span> <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_q</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">175</span> <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_k</span><span class="p">(</span><span class="n">cond</span><span class="p">)</span>
<span class="lineno">176</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_v</span><span class="p">(</span><span class="n">cond</span><span class="p">)</span>
<span class="lineno">177</span>
<span class="lineno">178</span> <span class="k">if</span> <span class="n">CrossAttention</span><span class="o">.</span><span class="n">use_flash_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">cond</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">128</span><span class="p">:</span>
<span class="lineno">178</span> <span class="k">if</span> <span class="n">CrossAttention</span><span class="o">.</span><span class="n">use_flash_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">has_cond</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o">&lt;=</span> <span class="mi">128</span><span class="p">:</span>
<span class="lineno">179</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
<span class="lineno">180</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">181</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">normal_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='section' id='section-40'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-39'>#</a>
<a href='#section-40'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">q</span></code>
are the query vectors before splitting heads, of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">d_attn</span><span class="p">]</span></code>
@ -630,10 +639,10 @@
<div class="highlight"><pre><span class="lineno">183</span> <span class="k">def</span> <span class="nf">flash_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
<a href='#section-41'>#</a>
</div>
</div>
@ -641,10 +650,10 @@
<div class="highlight"><pre><span class="lineno">190</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;flash&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
<a href='#section-42'>#</a>
</div>
<b>MarkdownException</b> + Italic: not ending with *
</div>
@ -652,10 +661,10 @@
<div class="highlight"><pre><span class="lineno">193</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
<a href='#section-43'>#</a>
</div>
<p>Stack <code class="highlight"><span></span><span class="n">q</span></code>
, <code class="highlight"><span></span><span class="n">k</span></code>
@ -668,10 +677,10 @@
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
<a href='#section-44'>#</a>
</div>
<p>Split the heads </p>
@ -680,10 +689,10 @@
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
<a href='#section-45'>#</a>
</div>
<p>Flash attention works for head sizes <code class="highlight"><span></span><span class="mi">32</span></code>
, <code class="highlight"><span></span><span class="mi">64</span></code>
@ -702,10 +711,10 @@
<span class="lineno">210</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Head size $</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="si">}</span><span class="s1"> too large for Flash Attention&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
<a href='#section-46'>#</a>
</div>
<p>Pad the heads </p>
@ -715,10 +724,10 @@
<span class="lineno">214</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">qkv</span><span class="p">,</span> <span class="n">qkv</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">pad</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
<a href='#section-47'>#</a>
</div>
<b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)V'
</div>
@ -726,10 +735,10 @@
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span><span class="p">(</span><span class="n">qkv</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
<a href='#section-48'>#</a>
</div>
<p>Truncate the extra head size </p>
@ -738,10 +747,10 @@
<div class="highlight"><pre><span class="lineno">221</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
<a href='#section-49'>#</a>
</div>
<p>Reshape to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">]</span></code>
</p>
@ -751,10 +760,10 @@
<div class="highlight"><pre><span class="lineno">223</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
<a href='#section-50'>#</a>
</div>
<p>Map to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
with a linear layer </p>
@ -764,10 +773,10 @@
<div class="highlight"><pre><span class="lineno">226</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='section' id='section-51'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-50'>#</a>
<a href='#section-51'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">q</span></code>
are the query vectors before splitting heads, of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">d_attn</span><span class="p">]</span></code>
@ -784,10 +793,10 @@
<div class="highlight"><pre><span class="lineno">228</span> <span class="k">def</span> <span class="nf">normal_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
<a href='#section-52'>#</a>
</div>
<p>Split them to heads of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_head</span><span class="p">]</span></code>
</p>
@ -799,10 +808,10 @@
<span class="lineno">238</span> <span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
<a href='#section-53'>#</a>
</div>
<b>KeyError</b> + '\\frac{Q K^\\top}{\\sqrt{d_{key}}}'
</div>
@ -810,10 +819,10 @@
<div class="highlight"><pre><span class="lineno">241</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bihd,bjhd-&gt;bhij&#39;</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
<a href='#section-54'>#</a>
</div>
<b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)'
</div>
@ -826,10 +835,10 @@
<span class="lineno">250</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
<a href='#section-55'>#</a>
</div>
<b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)V'
</div>
@ -837,10 +846,10 @@
<div class="highlight"><pre><span class="lineno">254</span> <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bhij,bjhd-&gt;bihd&#39;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
<a href='#section-56'>#</a>
</div>
<p>Reshape to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">]</span></code>
</p>
@ -850,10 +859,10 @@
<div class="highlight"><pre><span class="lineno">256</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">out</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
<a href='#section-57'>#</a>
</div>
<p>Map to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
with a linear layer </p>
@ -863,10 +872,10 @@
<div class="highlight"><pre><span class="lineno">258</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='section' id='section-58'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-57'>#</a>
<a href='#section-58'>#</a>
</div>
<h3>Feed-Forward Network</h3>
@ -875,10 +884,10 @@
<div class="highlight"><pre><span class="lineno">261</span><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='section' id='section-59'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-58'>#</a>
<a href='#section-59'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the input embedding size </li>
@ -890,10 +899,10 @@
<div class="highlight"><pre><span class="lineno">266</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_mult</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
<a href='#section-60'>#</a>
</div>
</div>
@ -906,10 +915,10 @@
<span class="lineno">276</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
<a href='#section-61'>#</a>
</div>
</div>
@ -918,10 +927,10 @@
<span class="lineno">279</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='section' id='section-62'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-61'>#</a>
<a href='#section-62'>#</a>
</div>
<b>KeyError</b> + '\\text{GeGLU}(x) = (xW + b) * \\text{GELU}(xV + c)'
</div>
@ -929,10 +938,10 @@
<div class="highlight"><pre><span class="lineno">282</span><span class="k">class</span> <span class="nc">GeGLU</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
<a href='#section-63'>#</a>
</div>
</div>
@ -941,10 +950,10 @@
<span class="lineno">290</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
<a href='#section-64'>#</a>
</div>
<b>KeyError</b> + 'xW + b'
</div>
@ -952,10 +961,10 @@
<div class="highlight"><pre><span class="lineno">292</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
<a href='#section-65'>#</a>
</div>
</div>
@ -963,10 +972,10 @@
<div class="highlight"><pre><span class="lineno">294</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
<a href='#section-66'>#</a>
</div>
<b>KeyError</b> + 'xW + b'
</div>
@ -974,10 +983,10 @@
<div class="highlight"><pre><span class="lineno">296</span> <span class="n">x</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
<a href='#section-67'>#</a>
</div>
<b>KeyError</b> + '\\text{GeGLU}(x) = (xW + b) * \\text{GELU}(xV + c)'
</div>

View File

@ -164,8 +164,10 @@ class CrossAttention(nn.Module):
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
"""
has_cond = cond is not None
# If `cond` is `None` we perform self attention
if cond is None:
if not has_cond:
cond = x
# Get query, key and value vectors
@ -173,9 +175,7 @@ class CrossAttention(nn.Module):
k = self.to_k(cond)
v = self.to_v(cond)
print('use flash', CrossAttention.use_flash_attention, self.flash)
if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
return self.flash_attention(q, k, v)
else:
return self.normal_attention(q, k, v)