fix value pe double rotation

This commit is contained in:
Varuna Jayasiri
2024-06-20 12:53:09 +05:30
parent 2236f6383c
commit 09d09379c2
2 changed files with 2 additions and 2 deletions

View File

@ -412,7 +412,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">234</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;ijbh,jbhd-&gt;ibhd&quot;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_rotary_pe</span><span class="p">(</span><span class="n">value</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">234</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;ijbh,jbhd-&gt;ibhd&quot;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>

View File

@ -231,7 +231,7 @@ class RotaryValuePEMultiHeadAttention(RotaryPEMultiHeadAttention):
# Multiply by values
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# Rotate in the opposite direction so that each embedding holds the relative positions
x = self.value_reverse_rotary_pe(x)