This commit is contained in:
Varuna Jayasiri
2022-06-03 10:13:03 +05:30
parent 669b920d6a
commit a450afd1bd
4 changed files with 68 additions and 66 deletions

View File

@ -235,7 +235,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span> <span class="o">=</span> <span class="n">idx_theta2</span><span class="o">.</span><span class="n">cos</span><span class="p">()[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <div class="highlight"><pre><span class="lineno">156</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span> <span class="o">=</span> <span class="n">idx_theta2</span><span class="o">.</span><span class="n">cos</span><span class="p">()[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="lineno">157</span> <span class="bp">self</span><span class="o">.</span><span class="n">sin_cached</span> <span class="o">=</span> <span class="n">idx_theta2</span><span class="o">.</span><span class="n">cos</span><span class="p">()[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span></pre></div> <span class="lineno">157</span> <span class="bp">self</span><span class="o">.</span><span class="n">sin_cached</span> <span class="o">=</span> <span class="n">idx_theta2</span><span class="o">.</span><span class="n">sin</span><span class="p">()[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-12'> <div class='section' id='section-12'>
@ -320,7 +320,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">177</span> <span class="n">neg_half_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_neg_half</span><span class="p">(</span><span class="n">x_rope</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">178</span> <span class="n">neg_half_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_neg_half</span><span class="p">(</span><span class="n">x_rope</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-19'> <div class='section' id='section-19'>
@ -333,7 +333,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">189</span> <span class="n">x_rope</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_rope</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span> <span class="o">+</span> <span class="p">(</span><span class="n">neg_half_x</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">sin_cached</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span></pre></div> <div class="highlight"><pre><span class="lineno">190</span> <span class="n">x_rope</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_rope</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span> <span class="o">+</span> <span class="p">(</span><span class="n">neg_half_x</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">sin_cached</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-20'> <div class='section' id='section-20'>
@ -345,7 +345,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">192</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x_rope</span><span class="p">,</span> <span class="n">x_pass</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">193</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x_rope</span><span class="p">,</span> <span class="n">x_pass</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-21'> <div class='section' id='section-21'>
@ -358,7 +358,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">195</span><span class="k">class</span> <span class="nc">RotaryPEMultiHeadAttention</span><span class="p">(</span><span class="n">MultiHeadAttention</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">196</span><span class="k">class</span> <span class="nc">RotaryPEMultiHeadAttention</span><span class="p">(</span><span class="n">MultiHeadAttention</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-22'> <div class='section' id='section-22'>
@ -369,7 +369,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">202</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">rope_percentage</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">203</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">rope_percentage</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-23'> <div class='section' id='section-23'>
@ -382,7 +382,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">206</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">heads</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">207</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">heads</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-24'> <div class='section' id='section-24'>
@ -394,9 +394,9 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">209</span> <span class="n">d_rope</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">*</span> <span class="n">rope_percentage</span><span class="p">)</span> <div class="highlight"><pre><span class="lineno">210</span> <span class="n">d_rope</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">*</span> <span class="n">rope_percentage</span><span class="p">)</span>
<span class="lineno">210</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope</span><span class="p">)</span> <span class="lineno">211</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope</span><span class="p">)</span>
<span class="lineno">211</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope</span><span class="p">)</span></pre></div> <span class="lineno">212</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-25'> <div class='section' id='section-25'>
@ -408,7 +408,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">213</span> <span class="k">def</span> <span class="nf">get_scores</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">214</span> <span class="k">def</span> <span class="nf">get_scores</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-26'> <div class='section' id='section-26'>
@ -420,7 +420,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">219</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;ibhd,jbhd-&gt;ijbh&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_rotary_pe</span><span class="p">(</span><span class="n">query</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_rotary_pe</span><span class="p">(</span><span class="n">key</span><span class="p">))</span></pre></div> <div class="highlight"><pre><span class="lineno">220</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;ibhd,jbhd-&gt;ijbh&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_rotary_pe</span><span class="p">(</span><span class="n">query</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_rotary_pe</span><span class="p">(</span><span class="n">key</span><span class="p">))</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-27'> <div class='section' id='section-27'>
@ -432,7 +432,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">222</span><span class="k">def</span> <span class="nf">_test_rotary</span><span class="p">():</span></pre></div> <div class="highlight"><pre><span class="lineno">223</span><span class="k">def</span> <span class="nf">_test_rotary</span><span class="p">():</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-28'> <div class='section' id='section-28'>
@ -443,16 +443,16 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">226</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">7</span><span class="p">],</span> <span class="p">[</span><span class="mi">7</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">9</span><span class="p">,</span> <span class="mi">10</span><span class="p">]],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span> <div class="highlight"><pre><span class="lineno">227</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">7</span><span class="p">],</span> <span class="p">[</span><span class="mi">7</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">9</span><span class="p">,</span> <span class="mi">10</span><span class="p">]],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
<span class="lineno">227</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="lineno">228</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="lineno">228</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="lineno">229</span> <span class="n">inspect</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">229</span> <span class="lineno">230</span>
<span class="lineno">230</span> <span class="n">rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span> <span class="lineno">231</span> <span class="n">rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
<span class="lineno">231</span> <span class="n">inspect</span><span class="p">(</span><span class="n">rotary_pe</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> <span class="lineno">232</span> <span class="n">inspect</span><span class="p">(</span><span class="n">rotary_pe</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
<span class="lineno">232</span>
<span class="lineno">233</span> <span class="lineno">233</span>
<span class="lineno">234</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span> <span class="lineno">234</span>
<span class="lineno">235</span> <span class="n">_test_rotary</span><span class="p">()</span></pre></div> <span class="lineno">235</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">236</span> <span class="n">_test_rotary</span><span class="p">()</span></pre></div>
</div> </div>
</div> </div>
<div class='footer'> <div class='footer'>

View File

@ -160,7 +160,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="n">neg_half_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_neg_half</span><span class="p">(</span><span class="n">x_rope</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">143</span> <span class="n">neg_half_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_neg_half</span><span class="p">(</span><span class="n">x_rope</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-6'> <div class='section' id='section-6'>
@ -173,7 +173,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">x_rope</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_rope</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span> <span class="o">-</span> <span class="p">(</span><span class="n">neg_half_x</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">sin_cached</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span></pre></div> <div class="highlight"><pre><span class="lineno">159</span> <span class="n">x_rope</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_rope</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span> <span class="o">-</span> <span class="p">(</span><span class="n">neg_half_x</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">sin_cached</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-7'> <div class='section' id='section-7'>
@ -185,7 +185,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">161</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x_rope</span><span class="p">,</span> <span class="n">x_pass</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">162</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x_rope</span><span class="p">,</span> <span class="n">x_pass</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-8'> <div class='section' id='section-8'>
@ -198,7 +198,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">164</span><span class="k">class</span> <span class="nc">RotaryValuePEMultiHeadAttention</span><span class="p">(</span><span class="n">MultiHeadAttention</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">165</span><span class="k">class</span> <span class="nc">RotaryValuePEMultiHeadAttention</span><span class="p">(</span><span class="n">MultiHeadAttention</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-9'> <div class='section' id='section-9'>
@ -209,9 +209,9 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <div class="highlight"><pre><span class="lineno">172</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">172</span> <span class="n">rope_percentage</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">rope_value_percentage</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="lineno">173</span> <span class="n">rope_percentage</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">rope_value_percentage</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
<span class="lineno">173</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">):</span></pre></div> <span class="lineno">174</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-10'> <div class='section' id='section-10'>
@ -224,7 +224,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">177</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">heads</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">178</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">heads</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-11'> <div class='section' id='section-11'>
@ -236,13 +236,13 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">180</span> <span class="n">d_rope</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">*</span> <span class="n">rope_percentage</span><span class="p">)</span> <div class="highlight"><pre><span class="lineno">181</span> <span class="n">d_rope</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">*</span> <span class="n">rope_percentage</span><span class="p">)</span>
<span class="lineno">181</span> <span class="n">d_rope_value</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">*</span> <span class="n">rope_value_percentage</span><span class="p">)</span> <span class="lineno">182</span> <span class="n">d_rope_value</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">*</span> <span class="n">rope_value_percentage</span><span class="p">)</span>
<span class="lineno">182</span> <span class="lineno">183</span>
<span class="lineno">183</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope</span><span class="p">)</span> <span class="lineno">184</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope</span><span class="p">)</span>
<span class="lineno">184</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope</span><span class="p">)</span> <span class="lineno">185</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope</span><span class="p">)</span>
<span class="lineno">185</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope_value</span><span class="p">)</span> <span class="lineno">186</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope_value</span><span class="p">)</span>
<span class="lineno">186</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_reverse_rotary_pe</span> <span class="o">=</span> <span class="n">ReverseRotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope_value</span><span class="p">)</span></pre></div> <span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_reverse_rotary_pe</span> <span class="o">=</span> <span class="n">ReverseRotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope_value</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-12'> <div class='section' id='section-12'>
@ -254,7 +254,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">188</span> <span class="k">def</span> <span class="nf">get_scores</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">189</span> <span class="k">def</span> <span class="nf">get_scores</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-13'> <div class='section' id='section-13'>
@ -266,7 +266,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">194</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;ibhd,jbhd-&gt;ijbh&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_rotary_pe</span><span class="p">(</span><span class="n">query</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_rotary_pe</span><span class="p">(</span><span class="n">key</span><span class="p">))</span></pre></div> <div class="highlight"><pre><span class="lineno">195</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;ibhd,jbhd-&gt;ijbh&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_rotary_pe</span><span class="p">(</span><span class="n">query</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_rotary_pe</span><span class="p">(</span><span class="n">key</span><span class="p">))</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-14'> <div class='section' id='section-14'>
@ -289,11 +289,11 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <div class="highlight"><pre><span class="lineno">197</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">197</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="lineno">198</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">198</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="lineno">199</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">199</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="lineno">200</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">200</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div> <span class="lineno">201</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-15'> <div class='section' id='section-15'>
@ -309,10 +309,10 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">212</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span> <div class="highlight"><pre><span class="lineno">213</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span>
<span class="lineno">213</span> <span class="lineno">214</span>
<span class="lineno">214</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="lineno">215</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">215</span> <span class="n">mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div> <span class="lineno">216</span> <span class="n">mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-16'> <div class='section' id='section-16'>
@ -328,9 +328,9 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">query</span><span class="p">)</span> <div class="highlight"><pre><span class="lineno">220</span> <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">query</span><span class="p">)</span>
<span class="lineno">220</span> <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="lineno">221</span> <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
<span class="lineno">221</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">value</span><span class="p">)</span></pre></div> <span class="lineno">222</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">value</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-17'> <div class='section' id='section-17'>
@ -343,7 +343,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">225</span> <span class="n">scores</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_scores</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">226</span> <span class="n">scores</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_scores</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-18'> <div class='section' id='section-18'>
@ -366,7 +366,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">scores</span> <span class="o">*=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div> <div class="highlight"><pre><span class="lineno">229</span> <span class="n">scores</span> <span class="o">*=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-19'> <div class='section' id='section-19'>
@ -378,8 +378,8 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">231</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <div class="highlight"><pre><span class="lineno">232</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">232</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span><span class="o">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">))</span></pre></div> <span class="lineno">233</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span><span class="o">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">))</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-20'> <div class='section' id='section-20'>
@ -402,7 +402,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">236</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">237</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-21'> <div class='section' id='section-21'>
@ -414,7 +414,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">239</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">240</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-22'> <div class='section' id='section-22'>
@ -426,7 +426,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">242</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_rotary_pe</span><span class="p">(</span><span class="n">value</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">243</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_rotary_pe</span><span class="p">(</span><span class="n">value</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-23'> <div class='section' id='section-23'>
@ -449,7 +449,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">246</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;ijbh,jbhd-&gt;ibhd&quot;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_rotary_pe</span><span class="p">(</span><span class="n">value</span><span class="p">))</span></pre></div> <div class="highlight"><pre><span class="lineno">247</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;ijbh,jbhd-&gt;ibhd&quot;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_rotary_pe</span><span class="p">(</span><span class="n">value</span><span class="p">))</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-24'> <div class='section' id='section-24'>
@ -461,7 +461,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">249</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_reverse_rotary_pe</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">250</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_reverse_rotary_pe</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-25'> <div class='section' id='section-25'>
@ -473,7 +473,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">252</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></pre></div> <div class="highlight"><pre><span class="lineno">253</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-26'> <div class='section' id='section-26'>
@ -485,7 +485,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">255</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">256</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-27'> <div class='section' id='section-27'>
@ -497,7 +497,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">258</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">259</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='footer'> <div class='footer'>

View File

@ -154,7 +154,7 @@ class RotaryPositionalEmbeddings(nn.Module):
# Cache them # Cache them
self.cos_cached = idx_theta2.cos()[:, None, None, :] self.cos_cached = idx_theta2.cos()[:, None, None, :]
self.sin_cached = idx_theta2.cos()[:, None, None, :] self.sin_cached = idx_theta2.sin()[:, None, None, :]
def _neg_half(self, x: torch.Tensor): def _neg_half(self, x: torch.Tensor):
# $\frac{d}{2}$ # $\frac{d}{2}$
@ -173,7 +173,8 @@ class RotaryPositionalEmbeddings(nn.Module):
# Split the features, we can choose to apply rotary embeddings only to a partial set of features. # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
x_rope, x_pass = x[..., :self.d], x[..., self.d:] x_rope, x_pass = x[..., :self.d], x[..., self.d:]
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ # Calculate
# $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
neg_half_x = self._neg_half(x_rope) neg_half_x = self._neg_half(x_rope)
# Calculate # Calculate

View File

@ -138,7 +138,8 @@ class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
# Split the features, we can choose to apply rotary embeddings only to a partial set of features. # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
x_rope, x_pass = x[..., :self.d], x[..., self.d:] x_rope, x_pass = x[..., :self.d], x[..., self.d:]
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ # Calculate
# $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
neg_half_x = self._neg_half(x_rope) neg_half_x = self._neg_half(x_rope)
# Calculate # Calculate