mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-30 02:08:50 +08:00
fix
This commit is contained in:
@ -579,14 +579,10 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-37'>#</a>
|
||||
</div>
|
||||
<p>If <code class="highlight"><span></span><span class="n">cond</span></code>
|
||||
is <code class="highlight"><span></span><span class="kc">None</span></code>
|
||||
we perform self attention </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">168</span> <span class="k">if</span> <span class="n">cond</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="lineno">169</span> <span class="n">cond</span> <span class="o">=</span> <span class="n">x</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">167</span> <span class="n">has_cond</span> <span class="o">=</span> <span class="n">cond</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-38'>
|
||||
@ -594,26 +590,39 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-38'>#</a>
|
||||
</div>
|
||||
<p>If <code class="highlight"><span></span><span class="n">cond</span></code>
|
||||
is <code class="highlight"><span></span><span class="kc">None</span></code>
|
||||
we perform self attention </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">170</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">has_cond</span><span class="p">:</span>
|
||||
<span class="lineno">171</span> <span class="n">cond</span> <span class="o">=</span> <span class="n">x</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-39'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-39'>#</a>
|
||||
</div>
|
||||
<p>Get query, key and value vectors </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">172</span> <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_q</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="lineno">173</span> <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_k</span><span class="p">(</span><span class="n">cond</span><span class="p">)</span>
|
||||
<span class="lineno">174</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_v</span><span class="p">(</span><span class="n">cond</span><span class="p">)</span>
|
||||
<span class="lineno">175</span>
|
||||
<span class="lineno">176</span> <span class="nb">print</span><span class="p">(</span><span class="s1">'use flash'</span><span class="p">,</span> <span class="n">CrossAttention</span><span class="o">.</span><span class="n">use_flash_attention</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span><span class="p">)</span>
|
||||
<div class="highlight"><pre><span class="lineno">174</span> <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_q</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="lineno">175</span> <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_k</span><span class="p">(</span><span class="n">cond</span><span class="p">)</span>
|
||||
<span class="lineno">176</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_v</span><span class="p">(</span><span class="n">cond</span><span class="p">)</span>
|
||||
<span class="lineno">177</span>
|
||||
<span class="lineno">178</span> <span class="k">if</span> <span class="n">CrossAttention</span><span class="o">.</span><span class="n">use_flash_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">cond</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o"><=</span> <span class="mi">128</span><span class="p">:</span>
|
||||
<span class="lineno">178</span> <span class="k">if</span> <span class="n">CrossAttention</span><span class="o">.</span><span class="n">use_flash_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">has_cond</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o"><=</span> <span class="mi">128</span><span class="p">:</span>
|
||||
<span class="lineno">179</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
|
||||
<span class="lineno">180</span> <span class="k">else</span><span class="p">:</span>
|
||||
<span class="lineno">181</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">normal_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-39'>
|
||||
<div class='section' id='section-40'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-39'>#</a>
|
||||
<a href='#section-40'>#</a>
|
||||
</div>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">q</span></code>
|
||||
are the query vectors before splitting heads, of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">d_attn</span><span class="p">]</span></code>
|
||||
@ -630,10 +639,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">183</span> <span class="k">def</span> <span class="nf">flash_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-40'>
|
||||
<div class='section' id='section-41'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-40'>#</a>
|
||||
<a href='#section-41'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
@ -641,10 +650,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">190</span> <span class="nb">print</span><span class="p">(</span><span class="s1">'flash'</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-41'>
|
||||
<div class='section' id='section-42'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-41'>#</a>
|
||||
<a href='#section-42'>#</a>
|
||||
</div>
|
||||
<b>MarkdownException</b> + Italic: not ending with *
|
||||
</div>
|
||||
@ -652,10 +661,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">193</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-42'>
|
||||
<div class='section' id='section-43'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-42'>#</a>
|
||||
<a href='#section-43'>#</a>
|
||||
</div>
|
||||
<p>Stack <code class="highlight"><span></span><span class="n">q</span></code>
|
||||
, <code class="highlight"><span></span><span class="n">k</span></code>
|
||||
@ -668,10 +677,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-43'>
|
||||
<div class='section' id='section-44'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-43'>#</a>
|
||||
<a href='#section-44'>#</a>
|
||||
</div>
|
||||
<p>Split the heads </p>
|
||||
|
||||
@ -680,10 +689,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-44'>
|
||||
<div class='section' id='section-45'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-44'>#</a>
|
||||
<a href='#section-45'>#</a>
|
||||
</div>
|
||||
<p>Flash attention works for head sizes <code class="highlight"><span></span><span class="mi">32</span></code>
|
||||
, <code class="highlight"><span></span><span class="mi">64</span></code>
|
||||
@ -702,10 +711,10 @@
|
||||
<span class="lineno">210</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Head size $</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="si">}</span><span class="s1"> too large for Flash Attention'</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-45'>
|
||||
<div class='section' id='section-46'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-45'>#</a>
|
||||
<a href='#section-46'>#</a>
|
||||
</div>
|
||||
<p>Pad the heads </p>
|
||||
|
||||
@ -715,10 +724,10 @@
|
||||
<span class="lineno">214</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">qkv</span><span class="p">,</span> <span class="n">qkv</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">pad</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-46'>
|
||||
<div class='section' id='section-47'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-46'>#</a>
|
||||
<a href='#section-47'>#</a>
|
||||
</div>
|
||||
<b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)V'
|
||||
</div>
|
||||
@ -726,10 +735,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span><span class="p">(</span><span class="n">qkv</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-47'>
|
||||
<div class='section' id='section-48'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-47'>#</a>
|
||||
<a href='#section-48'>#</a>
|
||||
</div>
|
||||
<p>Truncate the extra head size </p>
|
||||
|
||||
@ -738,10 +747,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">221</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-48'>
|
||||
<div class='section' id='section-49'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-48'>#</a>
|
||||
<a href='#section-49'>#</a>
|
||||
</div>
|
||||
<p>Reshape to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">]</span></code>
|
||||
</p>
|
||||
@ -751,10 +760,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">223</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-49'>
|
||||
<div class='section' id='section-50'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-49'>#</a>
|
||||
<a href='#section-50'>#</a>
|
||||
</div>
|
||||
<p>Map to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||||
with a linear layer </p>
|
||||
@ -764,10 +773,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">226</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-50'>
|
||||
<div class='section' id='section-51'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-50'>#</a>
|
||||
<a href='#section-51'>#</a>
|
||||
</div>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">q</span></code>
|
||||
are the query vectors before splitting heads, of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">d_attn</span><span class="p">]</span></code>
|
||||
@ -784,10 +793,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">228</span> <span class="k">def</span> <span class="nf">normal_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-51'>
|
||||
<div class='section' id='section-52'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-51'>#</a>
|
||||
<a href='#section-52'>#</a>
|
||||
</div>
|
||||
<p>Split them to heads of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_head</span><span class="p">]</span></code>
|
||||
</p>
|
||||
@ -799,10 +808,10 @@
|
||||
<span class="lineno">238</span> <span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-52'>
|
||||
<div class='section' id='section-53'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-52'>#</a>
|
||||
<a href='#section-53'>#</a>
|
||||
</div>
|
||||
<b>KeyError</b> + '\\frac{Q K^\\top}{\\sqrt{d_{key}}}'
|
||||
</div>
|
||||
@ -810,10 +819,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">241</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bihd,bjhd->bhij'</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-53'>
|
||||
<div class='section' id='section-54'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-53'>#</a>
|
||||
<a href='#section-54'>#</a>
|
||||
</div>
|
||||
<b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)'
|
||||
</div>
|
||||
@ -826,10 +835,10 @@
|
||||
<span class="lineno">250</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-54'>
|
||||
<div class='section' id='section-55'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-54'>#</a>
|
||||
<a href='#section-55'>#</a>
|
||||
</div>
|
||||
<b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)V'
|
||||
</div>
|
||||
@ -837,10 +846,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">254</span> <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bhij,bjhd->bihd'</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-55'>
|
||||
<div class='section' id='section-56'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-55'>#</a>
|
||||
<a href='#section-56'>#</a>
|
||||
</div>
|
||||
<p>Reshape to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">]</span></code>
|
||||
</p>
|
||||
@ -850,10 +859,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">256</span> <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">out</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-56'>
|
||||
<div class='section' id='section-57'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-56'>#</a>
|
||||
<a href='#section-57'>#</a>
|
||||
</div>
|
||||
<p>Map to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||||
with a linear layer </p>
|
||||
@ -863,10 +872,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">258</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-57'>
|
||||
<div class='section' id='section-58'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-57'>#</a>
|
||||
<a href='#section-58'>#</a>
|
||||
</div>
|
||||
<h3>Feed-Forward Network</h3>
|
||||
|
||||
@ -875,10 +884,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">261</span><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-58'>
|
||||
<div class='section' id='section-59'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-58'>#</a>
|
||||
<a href='#section-59'>#</a>
|
||||
</div>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
|
||||
is the input embedding size </li>
|
||||
@ -890,10 +899,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">266</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_mult</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-59'>
|
||||
<div class='section' id='section-60'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-59'>#</a>
|
||||
<a href='#section-60'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
@ -906,10 +915,10 @@
|
||||
<span class="lineno">276</span> <span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-60'>
|
||||
<div class='section' id='section-61'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-60'>#</a>
|
||||
<a href='#section-61'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
@ -918,10 +927,10 @@
|
||||
<span class="lineno">279</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-61'>
|
||||
<div class='section' id='section-62'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-61'>#</a>
|
||||
<a href='#section-62'>#</a>
|
||||
</div>
|
||||
<b>KeyError</b> + '\\text{GeGLU}(x) = (xW + b) * \\text{GELU}(xV + c)'
|
||||
</div>
|
||||
@ -929,10 +938,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">282</span><span class="k">class</span> <span class="nc">GeGLU</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-62'>
|
||||
<div class='section' id='section-63'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-62'>#</a>
|
||||
<a href='#section-63'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
@ -941,10 +950,10 @@
|
||||
<span class="lineno">290</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-63'>
|
||||
<div class='section' id='section-64'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-63'>#</a>
|
||||
<a href='#section-64'>#</a>
|
||||
</div>
|
||||
<b>KeyError</b> + 'xW + b'
|
||||
</div>
|
||||
@ -952,10 +961,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">292</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-64'>
|
||||
<div class='section' id='section-65'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-64'>#</a>
|
||||
<a href='#section-65'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
@ -963,10 +972,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">294</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-65'>
|
||||
<div class='section' id='section-66'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-65'>#</a>
|
||||
<a href='#section-66'>#</a>
|
||||
</div>
|
||||
<b>KeyError</b> + 'xW + b'
|
||||
</div>
|
||||
@ -974,10 +983,10 @@
|
||||
<div class="highlight"><pre><span class="lineno">296</span> <span class="n">x</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-66'>
|
||||
<div class='section' id='section-67'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-66'>#</a>
|
||||
<a href='#section-67'>#</a>
|
||||
</div>
|
||||
<b>KeyError</b> + '\\text{GeGLU}(x) = (xW + b) * \\text{GELU}(xV + c)'
|
||||
</div>
|
||||
|
||||
@ -164,8 +164,10 @@ class CrossAttention(nn.Module):
|
||||
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
|
||||
"""
|
||||
|
||||
has_cond = cond is not None
|
||||
|
||||
# If `cond` is `None` we perform self attention
|
||||
if cond is None:
|
||||
if not has_cond:
|
||||
cond = x
|
||||
|
||||
# Get query, key and value vectors
|
||||
@ -173,9 +175,7 @@ class CrossAttention(nn.Module):
|
||||
k = self.to_k(cond)
|
||||
v = self.to_v(cond)
|
||||
|
||||
print('use flash', CrossAttention.use_flash_attention, self.flash)
|
||||
|
||||
if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
|
||||
if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
|
||||
return self.flash_attention(q, k, v)
|
||||
else:
|
||||
return self.normal_attention(q, k, v)
|
||||
|
||||
Reference in New Issue
Block a user