mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 02:39:16 +08:00 
			
		
		
		
	fix
This commit is contained in:
		| @ -579,14 +579,10 @@ | |||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-37'>#</a> |                 <a href='#section-37'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <p>If <code  class="highlight"><span></span><span class="n">cond</span></code> |  | ||||||
|  is <code  class="highlight"><span></span><span class="kc">None</span></code> |  | ||||||
|  we perform self attention </p> |  | ||||||
|              |              | ||||||
|         </div> |         </div> | ||||||
|         <div class='code'> |         <div class='code'> | ||||||
|             <div class="highlight"><pre><span class="lineno">168</span>        <span class="k">if</span> <span class="n">cond</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> |             <div class="highlight"><pre><span class="lineno">167</span>        <span class="n">has_cond</span> <span class="o">=</span> <span class="n">cond</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span></pre></div> | ||||||
| <span class="lineno">169</span>            <span class="n">cond</span> <span class="o">=</span> <span class="n">x</span></pre></div> |  | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-38'> |     <div class='section' id='section-38'> | ||||||
| @ -594,26 +590,39 @@ | |||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-38'>#</a> |                 <a href='#section-38'>#</a> | ||||||
|             </div> |             </div> | ||||||
|  |             <p>If <code  class="highlight"><span></span><span class="n">cond</span></code> | ||||||
|  |  is <code  class="highlight"><span></span><span class="kc">None</span></code> | ||||||
|  |  we perform self attention </p> | ||||||
|  |  | ||||||
|  |         </div> | ||||||
|  |         <div class='code'> | ||||||
|  |             <div class="highlight"><pre><span class="lineno">170</span>        <span class="k">if</span> <span class="ow">not</span> <span class="n">has_cond</span><span class="p">:</span> | ||||||
|  | <span class="lineno">171</span>            <span class="n">cond</span> <span class="o">=</span> <span class="n">x</span></pre></div> | ||||||
|  |         </div> | ||||||
|  |     </div> | ||||||
|  |     <div class='section' id='section-39'> | ||||||
|  |         <div class='docs'> | ||||||
|  |             <div class='section-link'> | ||||||
|  |                 <a href='#section-39'>#</a> | ||||||
|  |             </div> | ||||||
|             <p>Get query, key and value vectors </p> |             <p>Get query, key and value vectors </p> | ||||||
|  |  | ||||||
|         </div> |         </div> | ||||||
|         <div class='code'> |         <div class='code'> | ||||||
|             <div class="highlight"><pre><span class="lineno">172</span>        <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_q</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> |             <div class="highlight"><pre><span class="lineno">174</span>        <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_q</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> | ||||||
| <span class="lineno">173</span>        <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_k</span><span class="p">(</span><span class="n">cond</span><span class="p">)</span> | <span class="lineno">175</span>        <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_k</span><span class="p">(</span><span class="n">cond</span><span class="p">)</span> | ||||||
| <span class="lineno">174</span>        <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_v</span><span class="p">(</span><span class="n">cond</span><span class="p">)</span> | <span class="lineno">176</span>        <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_v</span><span class="p">(</span><span class="n">cond</span><span class="p">)</span> | ||||||
| <span class="lineno">175</span> |  | ||||||
| <span class="lineno">176</span>        <span class="nb">print</span><span class="p">(</span><span class="s1">'use flash'</span><span class="p">,</span> <span class="n">CrossAttention</span><span class="o">.</span><span class="n">use_flash_attention</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span><span class="p">)</span> |  | ||||||
| <span class="lineno">177</span> | <span class="lineno">177</span> | ||||||
| <span class="lineno">178</span>        <span class="k">if</span> <span class="n">CrossAttention</span><span class="o">.</span><span class="n">use_flash_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">cond</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o"><=</span> <span class="mi">128</span><span class="p">:</span> | <span class="lineno">178</span>        <span class="k">if</span> <span class="n">CrossAttention</span><span class="o">.</span><span class="n">use_flash_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">has_cond</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span> <span class="o"><=</span> <span class="mi">128</span><span class="p">:</span> | ||||||
| <span class="lineno">179</span>            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span> | <span class="lineno">179</span>            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span> | ||||||
| <span class="lineno">180</span>        <span class="k">else</span><span class="p">:</span> | <span class="lineno">180</span>        <span class="k">else</span><span class="p">:</span> | ||||||
| <span class="lineno">181</span>            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">normal_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div> | <span class="lineno">181</span>            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">normal_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-39'> |     <div class='section' id='section-40'> | ||||||
|         <div class='docs doc-strings'> |         <div class='docs doc-strings'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-39'>#</a> |                 <a href='#section-40'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <ul><li><code  class="highlight"><span></span><span class="n">q</span></code> |             <ul><li><code  class="highlight"><span></span><span class="n">q</span></code> | ||||||
|   are the query vectors before splitting heads, of shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">d_attn</span><span class="p">]</span></code> |   are the query vectors before splitting heads, of shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">d_attn</span><span class="p">]</span></code> | ||||||
| @ -630,10 +639,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">183</span>    <span class="k">def</span> <span class="nf">flash_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> |             <div class="highlight"><pre><span class="lineno">183</span>    <span class="k">def</span> <span class="nf">flash_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-40'> |     <div class='section' id='section-41'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-40'>#</a> |                 <a href='#section-41'>#</a> | ||||||
|             </div> |             </div> | ||||||
|              |              | ||||||
|         </div> |         </div> | ||||||
| @ -641,10 +650,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">190</span>        <span class="nb">print</span><span class="p">(</span><span class="s1">'flash'</span><span class="p">)</span></pre></div> |             <div class="highlight"><pre><span class="lineno">190</span>        <span class="nb">print</span><span class="p">(</span><span class="s1">'flash'</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-41'> |     <div class='section' id='section-42'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-41'>#</a> |                 <a href='#section-42'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <b>MarkdownException</b> + Italic: not ending with * |             <b>MarkdownException</b> + Italic: not ending with * | ||||||
|         </div> |         </div> | ||||||
| @ -652,10 +661,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">193</span>        <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span></pre></div> |             <div class="highlight"><pre><span class="lineno">193</span>        <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-42'> |     <div class='section' id='section-43'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-42'>#</a> |                 <a href='#section-43'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <p>Stack <code  class="highlight"><span></span><span class="n">q</span></code> |             <p>Stack <code  class="highlight"><span></span><span class="n">q</span></code> | ||||||
| , <code  class="highlight"><span></span><span class="n">k</span></code> | , <code  class="highlight"><span></span><span class="n">k</span></code> | ||||||
| @ -668,10 +677,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">197</span>        <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div> |             <div class="highlight"><pre><span class="lineno">197</span>        <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-43'> |     <div class='section' id='section-44'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-43'>#</a> |                 <a href='#section-44'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <p>Split the heads </p> |             <p>Split the heads </p> | ||||||
|  |  | ||||||
| @ -680,10 +689,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">199</span>        <span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div> |             <div class="highlight"><pre><span class="lineno">199</span>        <span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-44'> |     <div class='section' id='section-45'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-44'>#</a> |                 <a href='#section-45'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <p>Flash attention works for head sizes <code  class="highlight"><span></span><span class="mi">32</span></code> |             <p>Flash attention works for head sizes <code  class="highlight"><span></span><span class="mi">32</span></code> | ||||||
| , <code  class="highlight"><span></span><span class="mi">64</span></code> | , <code  class="highlight"><span></span><span class="mi">64</span></code> | ||||||
| @ -702,10 +711,10 @@ | |||||||
| <span class="lineno">210</span>            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Head size $</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="si">}</span><span class="s1"> too large for Flash Attention'</span><span class="p">)</span></pre></div> | <span class="lineno">210</span>            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Head size $</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="si">}</span><span class="s1"> too large for Flash Attention'</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-45'> |     <div class='section' id='section-46'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-45'>#</a> |                 <a href='#section-46'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <p>Pad the heads </p> |             <p>Pad the heads </p> | ||||||
|  |  | ||||||
| @ -715,10 +724,10 @@ | |||||||
| <span class="lineno">214</span>            <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">qkv</span><span class="p">,</span> <span class="n">qkv</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">pad</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div> | <span class="lineno">214</span>            <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">qkv</span><span class="p">,</span> <span class="n">qkv</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">pad</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-46'> |     <div class='section' id='section-47'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-46'>#</a> |                 <a href='#section-47'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)V' |             <b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)V' | ||||||
|         </div> |         </div> | ||||||
| @ -726,10 +735,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">219</span>        <span class="n">out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span><span class="p">(</span><span class="n">qkv</span><span class="p">)</span></pre></div> |             <div class="highlight"><pre><span class="lineno">219</span>        <span class="n">out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash</span><span class="p">(</span><span class="n">qkv</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-47'> |     <div class='section' id='section-48'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-47'>#</a> |                 <a href='#section-48'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <p>Truncate the extra head size </p> |             <p>Truncate the extra head size </p> | ||||||
|  |  | ||||||
| @ -738,10 +747,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">221</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">]</span></pre></div> |             <div class="highlight"><pre><span class="lineno">221</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">]</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-48'> |     <div class='section' id='section-49'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-48'>#</a> |                 <a href='#section-49'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <p>Reshape to <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">]</span></code> |             <p>Reshape to <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">]</span></code> | ||||||
|  </p> |  </p> | ||||||
| @ -751,10 +760,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">223</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div> |             <div class="highlight"><pre><span class="lineno">223</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_head</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-49'> |     <div class='section' id='section-50'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-49'>#</a> |                 <a href='#section-50'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <p>Map to <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code> |             <p>Map to <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code> | ||||||
|  with a linear layer </p> |  with a linear layer </p> | ||||||
| @ -764,10 +773,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">226</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div> |             <div class="highlight"><pre><span class="lineno">226</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-50'> |     <div class='section' id='section-51'> | ||||||
|         <div class='docs doc-strings'> |         <div class='docs doc-strings'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-50'>#</a> |                 <a href='#section-51'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <ul><li><code  class="highlight"><span></span><span class="n">q</span></code> |             <ul><li><code  class="highlight"><span></span><span class="n">q</span></code> | ||||||
|   are the query vectors before splitting heads, of shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">d_attn</span><span class="p">]</span></code> |   are the query vectors before splitting heads, of shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">d_attn</span><span class="p">]</span></code> | ||||||
| @ -784,10 +793,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">228</span>    <span class="k">def</span> <span class="nf">normal_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> |             <div class="highlight"><pre><span class="lineno">228</span>    <span class="k">def</span> <span class="nf">normal_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-51'> |     <div class='section' id='section-52'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-51'>#</a> |                 <a href='#section-52'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <p>Split them to heads of shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_head</span><span class="p">]</span></code> |             <p>Split them to heads of shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_head</span><span class="p">]</span></code> | ||||||
|  </p> |  </p> | ||||||
| @ -799,10 +808,10 @@ | |||||||
| <span class="lineno">238</span>        <span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div> | <span class="lineno">238</span>        <span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-52'> |     <div class='section' id='section-53'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-52'>#</a> |                 <a href='#section-53'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <b>KeyError</b> + '\\frac{Q K^\\top}{\\sqrt{d_{key}}}' |             <b>KeyError</b> + '\\frac{Q K^\\top}{\\sqrt{d_{key}}}' | ||||||
|         </div> |         </div> | ||||||
| @ -810,10 +819,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">241</span>        <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bihd,bjhd->bhij'</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div> |             <div class="highlight"><pre><span class="lineno">241</span>        <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bihd,bjhd->bhij'</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-53'> |     <div class='section' id='section-54'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-53'>#</a> |                 <a href='#section-54'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)' |             <b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)' | ||||||
|         </div> |         </div> | ||||||
| @ -826,10 +835,10 @@ | |||||||
| <span class="lineno">250</span>            <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div> | <span class="lineno">250</span>            <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-54'> |     <div class='section' id='section-55'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-54'>#</a> |                 <a href='#section-55'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)V' |             <b>KeyError</b> + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)V' | ||||||
|         </div> |         </div> | ||||||
| @ -837,10 +846,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">254</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bhij,bjhd->bihd'</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div> |             <div class="highlight"><pre><span class="lineno">254</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bhij,bjhd->bihd'</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-55'> |     <div class='section' id='section-56'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-55'>#</a> |                 <a href='#section-56'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <p>Reshape to <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">]</span></code> |             <p>Reshape to <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_head</span><span class="p">]</span></code> | ||||||
|  </p> |  </p> | ||||||
| @ -850,10 +859,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">256</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">out</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div> |             <div class="highlight"><pre><span class="lineno">256</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">out</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-56'> |     <div class='section' id='section-57'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-56'>#</a> |                 <a href='#section-57'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <p>Map to <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code> |             <p>Map to <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code> | ||||||
|  with a linear layer </p> |  with a linear layer </p> | ||||||
| @ -863,10 +872,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">258</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div> |             <div class="highlight"><pre><span class="lineno">258</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_out</span><span class="p">(</span><span class="n">out</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-57'> |     <div class='section' id='section-58'> | ||||||
|         <div class='docs doc-strings'> |         <div class='docs doc-strings'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-57'>#</a> |                 <a href='#section-58'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <h3>Feed-Forward Network</h3> |             <h3>Feed-Forward Network</h3> | ||||||
|  |  | ||||||
| @ -875,10 +884,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">261</span><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div> |             <div class="highlight"><pre><span class="lineno">261</span><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-58'> |     <div class='section' id='section-59'> | ||||||
|         <div class='docs doc-strings'> |         <div class='docs doc-strings'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-58'>#</a> |                 <a href='#section-59'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <ul><li><code  class="highlight"><span></span><span class="n">d_model</span></code> |             <ul><li><code  class="highlight"><span></span><span class="n">d_model</span></code> | ||||||
|   is the input embedding size </li> |   is the input embedding size </li> | ||||||
| @ -890,10 +899,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">266</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_mult</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span><span class="p">):</span></pre></div> |             <div class="highlight"><pre><span class="lineno">266</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_mult</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span><span class="p">):</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-59'> |     <div class='section' id='section-60'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-59'>#</a> |                 <a href='#section-60'>#</a> | ||||||
|             </div> |             </div> | ||||||
|              |              | ||||||
|         </div> |         </div> | ||||||
| @ -906,10 +915,10 @@ | |||||||
| <span class="lineno">276</span>        <span class="p">)</span></pre></div> | <span class="lineno">276</span>        <span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-60'> |     <div class='section' id='section-61'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-60'>#</a> |                 <a href='#section-61'>#</a> | ||||||
|             </div> |             </div> | ||||||
|              |              | ||||||
|         </div> |         </div> | ||||||
| @ -918,10 +927,10 @@ | |||||||
| <span class="lineno">279</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> | <span class="lineno">279</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-61'> |     <div class='section' id='section-62'> | ||||||
|         <div class='docs doc-strings'> |         <div class='docs doc-strings'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-61'>#</a> |                 <a href='#section-62'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <b>KeyError</b> + '\\text{GeGLU}(x) = (xW + b) * \\text{GELU}(xV + c)' |             <b>KeyError</b> + '\\text{GeGLU}(x) = (xW + b) * \\text{GELU}(xV + c)' | ||||||
|         </div> |         </div> | ||||||
| @ -929,10 +938,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">282</span><span class="k">class</span> <span class="nc">GeGLU</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div> |             <div class="highlight"><pre><span class="lineno">282</span><span class="k">class</span> <span class="nc">GeGLU</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-62'> |     <div class='section' id='section-63'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-62'>#</a> |                 <a href='#section-63'>#</a> | ||||||
|             </div> |             </div> | ||||||
|              |              | ||||||
|         </div> |         </div> | ||||||
| @ -941,10 +950,10 @@ | |||||||
| <span class="lineno">290</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div> | <span class="lineno">290</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-63'> |     <div class='section' id='section-64'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-63'>#</a> |                 <a href='#section-64'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <b>KeyError</b> + 'xW + b' |             <b>KeyError</b> + 'xW + b' | ||||||
|         </div> |         </div> | ||||||
| @ -952,10 +961,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">292</span>        <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span></pre></div> |             <div class="highlight"><pre><span class="lineno">292</span>        <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_in</span><span class="p">,</span> <span class="n">d_out</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-64'> |     <div class='section' id='section-65'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-64'>#</a> |                 <a href='#section-65'>#</a> | ||||||
|             </div> |             </div> | ||||||
|              |              | ||||||
|         </div> |         </div> | ||||||
| @ -963,10 +972,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">294</span>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> |             <div class="highlight"><pre><span class="lineno">294</span>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-65'> |     <div class='section' id='section-66'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-65'>#</a> |                 <a href='#section-66'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <b>KeyError</b> + 'xW + b' |             <b>KeyError</b> + 'xW + b' | ||||||
|         </div> |         </div> | ||||||
| @ -974,10 +983,10 @@ | |||||||
|             <div class="highlight"><pre><span class="lineno">296</span>        <span class="n">x</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div> |             <div class="highlight"><pre><span class="lineno">296</span>        <span class="n">x</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|     <div class='section' id='section-66'> |     <div class='section' id='section-67'> | ||||||
|         <div class='docs'> |         <div class='docs'> | ||||||
|             <div class='section-link'> |             <div class='section-link'> | ||||||
|                 <a href='#section-66'>#</a> |                 <a href='#section-67'>#</a> | ||||||
|             </div> |             </div> | ||||||
|             <b>KeyError</b> + '\\text{GeGLU}(x) = (xW + b) * \\text{GELU}(xV + c)' |             <b>KeyError</b> + '\\text{GeGLU}(x) = (xW + b) * \\text{GELU}(xV + c)' | ||||||
|         </div> |         </div> | ||||||
|  | |||||||
| @ -164,8 +164,10 @@ class CrossAttention(nn.Module): | |||||||
|         :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` |         :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|  |         has_cond = cond is not None | ||||||
|  |  | ||||||
|         # If `cond` is `None` we perform self attention |         # If `cond` is `None` we perform self attention | ||||||
|         if cond is None: |         if not has_cond: | ||||||
|             cond = x |             cond = x | ||||||
|  |  | ||||||
|         # Get query, key and value vectors |         # Get query, key and value vectors | ||||||
| @ -173,9 +175,7 @@ class CrossAttention(nn.Module): | |||||||
|         k = self.to_k(cond) |         k = self.to_k(cond) | ||||||
|         v = self.to_v(cond) |         v = self.to_v(cond) | ||||||
|  |  | ||||||
|         print('use flash', CrossAttention.use_flash_attention, self.flash) |         if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128: | ||||||
|  |  | ||||||
|         if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128: |  | ||||||
|             return self.flash_attention(q, k, v) |             return self.flash_attention(q, k, v) | ||||||
|         else: |         else: | ||||||
|             return self.normal_attention(q, k, v) |             return self.normal_attention(q, k, v) | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri