mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 17:41:37 +08:00
docs
This commit is contained in:
@ -447,7 +447,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">149</span> <span class="n">n_tests</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">149</span> <span class="n">n_tests</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-29'>
|
||||
@ -496,7 +496,8 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">157</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">157</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
|
||||
<span class="lineno">158</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-33'>
|
||||
@ -508,8 +509,8 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">165</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_loop</span><span class="o">.</span><span class="n">idx</span> <span class="o"><</span> <span class="mi">1</span><span class="p">:</span>
|
||||
<span class="lineno">166</span> <span class="k">return</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">166</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_loop</span><span class="o">.</span><span class="n">idx</span> <span class="o"><</span> <span class="mi">1</span><span class="p">:</span>
|
||||
<span class="lineno">167</span> <span class="k">return</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-34'>
|
||||
@ -521,7 +522,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">ArithmeticDataset</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_digits</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">170</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">ArithmeticDataset</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_digits</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-35'>
|
||||
@ -533,7 +534,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">171</span> <span class="n">qa</span> <span class="o">=</span> <span class="p">[</span><span class="n">dataset</span><span class="o">.</span><span class="n">get_qa</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_tests</span><span class="p">)]</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">172</span> <span class="n">qa</span> <span class="o">=</span> <span class="p">[</span><span class="n">dataset</span><span class="o">.</span><span class="n">get_qa</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_tests</span><span class="p">)]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-36'>
|
||||
@ -545,7 +546,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">questions</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">qa</span><span class="p">]</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">174</span> <span class="n">questions</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">qa</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-37'>
|
||||
@ -557,7 +558,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">176</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="n">dataset</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">questions</span><span class="p">]])</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">177</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="n">dataset</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">questions</span><span class="p">]])</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-38'>
|
||||
@ -569,7 +570,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">178</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">179</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-39'>
|
||||
@ -581,7 +582,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">finished</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">questions</span><span class="p">),))</span><span class="o">.</span><span class="n">bool</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">182</span> <span class="n">finished</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">questions</span><span class="p">),))</span><span class="o">.</span><span class="n">bool</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-40'>
|
||||
@ -593,7 +594,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">183</span> <span class="n">new_line</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="s1">'</span><span class="se">\n</span><span class="s1">'</span><span class="p">]</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">184</span> <span class="n">new_line</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="s1">'</span><span class="se">\n</span><span class="s1">'</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-41'>
|
||||
@ -605,7 +606,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">186</span> <span class="n">results</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">questions</span><span class="p">]</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">187</span> <span class="n">results</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">questions</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-42'>
|
||||
@ -617,7 +618,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">189</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">'Sample'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">190</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">'Sample'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-43'>
|
||||
@ -629,8 +630,8 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">191</span> <span class="k">if</span> <span class="n">finished</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">finished</span><span class="p">):</span>
|
||||
<span class="lineno">192</span> <span class="k">continue</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">192</span> <span class="k">if</span> <span class="n">finished</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">finished</span><span class="p">):</span>
|
||||
<span class="lineno">193</span> <span class="k">continue</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-44'>
|
||||
@ -642,7 +643,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">195</span> <span class="n">output</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">196</span> <span class="n">output</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-45'>
|
||||
@ -654,7 +655,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">output</span> <span class="o">=</span> <span class="n">output</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">198</span> <span class="n">output</span> <span class="o">=</span> <span class="n">output</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-46'>
|
||||
@ -666,7 +667,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">200</span> <span class="n">finished</span> <span class="o">=</span> <span class="n">finished</span> <span class="o">|</span> <span class="p">(</span><span class="n">output</span> <span class="o">==</span> <span class="n">new_line</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">finished</span> <span class="o">=</span> <span class="n">finished</span> <span class="o">|</span> <span class="p">(</span><span class="n">output</span> <span class="o">==</span> <span class="n">new_line</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-47'>
|
||||
@ -678,8 +679,8 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">202</span> <span class="k">if</span> <span class="n">finished</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">finished</span><span class="p">):</span>
|
||||
<span class="lineno">203</span> <span class="k">continue</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">if</span> <span class="n">finished</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">finished</span><span class="p">):</span>
|
||||
<span class="lineno">204</span> <span class="k">continue</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-48'>
|
||||
@ -691,9 +692,9 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">206</span> <span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">questions</span><span class="p">):</span>
|
||||
<span class="lineno">207</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">p</span><span class="p">)</span> <span class="o">></span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:</span>
|
||||
<span class="lineno">208</span> <span class="n">output</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]]</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">207</span> <span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">questions</span><span class="p">):</span>
|
||||
<span class="lineno">208</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">p</span><span class="p">)</span> <span class="o">></span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:</span>
|
||||
<span class="lineno">209</span> <span class="n">output</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-49'>
|
||||
@ -705,7 +706,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">211</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">data</span><span class="p">,</span> <span class="n">output</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">212</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">data</span><span class="p">,</span> <span class="n">output</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-50'>
|
||||
@ -717,8 +718,8 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">214</span> <span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">output</span><span class="p">):</span>
|
||||
<span class="lineno">215</span> <span class="n">results</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">c</span><span class="p">]</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">215</span> <span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">output</span><span class="p">):</span>
|
||||
<span class="lineno">216</span> <span class="n">results</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">c</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-51'>
|
||||
@ -730,7 +731,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">218</span> <span class="n">results</span> <span class="o">=</span> <span class="p">[</span><span class="n">r</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">'</span><span class="se">\n</span><span class="s1">'</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">results</span><span class="p">]</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">results</span> <span class="o">=</span> <span class="p">[</span><span class="n">r</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">'</span><span class="se">\n</span><span class="s1">'</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">results</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-52'>
|
||||
@ -742,8 +743,8 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">221</span> <span class="n">res_sample</span> <span class="o">=</span> <span class="n">results</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">';'</span><span class="p">)</span>
|
||||
<span class="lineno">222</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">([(</span><span class="n">res_sample</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">Text</span><span class="o">.</span><span class="n">key</span><span class="p">),</span> <span class="p">(</span><span class="s1">';'</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span> <span class="p">(</span><span class="s1">';'</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">res_sample</span><span class="p">[</span><span class="mi">1</span><span class="p">:]),</span> <span class="n">Text</span><span class="o">.</span><span class="n">none</span><span class="p">)])</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">222</span> <span class="n">res_sample</span> <span class="o">=</span> <span class="n">results</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">';'</span><span class="p">)</span>
|
||||
<span class="lineno">223</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">([(</span><span class="n">res_sample</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">Text</span><span class="o">.</span><span class="n">key</span><span class="p">),</span> <span class="p">(</span><span class="s1">';'</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span> <span class="p">(</span><span class="s1">';'</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">res_sample</span><span class="p">[</span><span class="mi">1</span><span class="p">:]),</span> <span class="n">Text</span><span class="o">.</span><span class="n">none</span><span class="p">)])</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-53'>
|
||||
@ -755,7 +756,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">225</span> <span class="n">results</span> <span class="o">=</span> <span class="p">[</span><span class="n">r</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">'x=='</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">results</span><span class="p">]</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">226</span> <span class="n">results</span> <span class="o">=</span> <span class="p">[</span><span class="n">r</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">'x=='</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">results</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-54'>
|
||||
@ -767,10 +768,10 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">correct</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="lineno">229</span> <span class="k">for</span> <span class="n">r</span><span class="p">,</span> <span class="n">_qa</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">results</span><span class="p">,</span> <span class="n">qa</span><span class="p">):</span>
|
||||
<span class="lineno">230</span> <span class="k">if</span> <span class="n">r</span> <span class="o">==</span> <span class="n">_qa</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
|
||||
<span class="lineno">231</span> <span class="n">correct</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">229</span> <span class="n">correct</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="lineno">230</span> <span class="k">for</span> <span class="n">r</span><span class="p">,</span> <span class="n">_qa</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">results</span><span class="p">,</span> <span class="n">qa</span><span class="p">):</span>
|
||||
<span class="lineno">231</span> <span class="k">if</span> <span class="n">r</span> <span class="o">==</span> <span class="n">_qa</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
|
||||
<span class="lineno">232</span> <span class="n">correct</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-55'>
|
||||
@ -782,7 +783,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">234</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">'score'</span><span class="p">,</span> <span class="n">correct</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">results</span><span class="p">))</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">235</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">'score'</span><span class="p">,</span> <span class="n">correct</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">results</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-56'>
|
||||
@ -794,8 +795,8 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">237</span><span class="nd">@option</span><span class="p">(</span><span class="n">ArithmeticAutoregression</span><span class="o">.</span><span class="n">train_loader</span><span class="p">)</span>
|
||||
<span class="lineno">238</span><span class="k">def</span> <span class="nf">arithmetic_train_loader</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">ArithmeticAutoregression</span><span class="p">):</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">238</span><span class="nd">@option</span><span class="p">(</span><span class="n">ArithmeticAutoregression</span><span class="o">.</span><span class="n">train_loader</span><span class="p">)</span>
|
||||
<span class="lineno">239</span><span class="k">def</span> <span class="nf">arithmetic_train_loader</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">ArithmeticAutoregression</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-57'>
|
||||
@ -806,10 +807,10 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">242</span> <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">ArithmeticDataset</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">max_digits</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">train_sequences_per_epoch</span><span class="p">),</span>
|
||||
<span class="lineno">243</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span>
|
||||
<span class="lineno">244</span> <span class="n">collate_fn</span><span class="o">=</span><span class="n">transpose_batch</span><span class="p">,</span>
|
||||
<span class="lineno">245</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">243</span> <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">ArithmeticDataset</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">max_digits</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">train_sequences_per_epoch</span><span class="p">),</span>
|
||||
<span class="lineno">244</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span>
|
||||
<span class="lineno">245</span> <span class="n">collate_fn</span><span class="o">=</span><span class="n">transpose_batch</span><span class="p">,</span>
|
||||
<span class="lineno">246</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-58'>
|
||||
@ -821,7 +822,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">248</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">249</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-59'>
|
||||
@ -832,9 +833,9 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">252</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">ArithmeticDataset</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
|
||||
<span class="lineno">253</span>
|
||||
<span class="lineno">254</span> <span class="nb">print</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">get_packed_math_input</span><span class="p">()))</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">253</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">ArithmeticDataset</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
|
||||
<span class="lineno">254</span>
|
||||
<span class="lineno">255</span> <span class="nb">print</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">get_packed_math_input</span><span class="p">()))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-60'>
|
||||
@ -846,8 +847,8 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">258</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">259</span> <span class="n">_test</span><span class="p">()</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">259</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">260</span> <span class="n">_test</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
|
@ -370,7 +370,7 @@
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">rope_percentage</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">):</span>
|
||||
<span class="lineno">204</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">heads</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
||||
<span class="lineno">204</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">heads</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-23'>
|
||||
|
@ -163,7 +163,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">45</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"roper_addition"</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s2">"rotary value 8"</span><span class="p">,</span> <span class="n">writers</span><span class="o">=</span><span class="p">{</span><span class="s1">'screen'</span><span class="p">,</span> <span class="s1">'labml'</span><span class="p">,</span> <span class="s1">'comet'</span><span class="p">})</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">45</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"roper_addition"</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s2">"rotary value 7"</span><span class="p">,</span> <span class="n">writers</span><span class="o">=</span><span class="p">{</span><span class="s1">'screen'</span><span class="p">,</span> <span class="s1">'labml'</span><span class="p">,</span> <span class="s1">'comet'</span><span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
@ -188,7 +188,7 @@
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span>
|
||||
<span class="lineno">50</span> <span class="s1">'max_digits'</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span></pre></div>
|
||||
<span class="lineno">50</span> <span class="s1">'max_digits'</span><span class="p">:</span> <span class="mi">7</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
@ -296,12 +296,12 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>Use <a href="../../optimizers/noam.html">Noam optimizer</a> </p>
|
||||
<p>Use <a href="../../optimizers/noam.html">Adam optimizer</a> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">78</span> <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Noam'</span><span class="p">,</span>
|
||||
<span class="lineno">79</span> <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">1.</span><span class="p">,</span>
|
||||
<div class="highlight"><pre><span class="lineno">78</span> <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Adam'</span><span class="p">,</span>
|
||||
<span class="lineno">79</span> <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">2.5e-4</span><span class="p">,</span>
|
||||
<span class="lineno">80</span> <span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
|
@ -116,7 +116,7 @@
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">26</span><span class="k">def</span> <span class="nf">_rotary_value_pe_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
|
||||
<span class="lineno">27</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.rope.value_pe</span> <span class="kn">import</span> <span class="n">RotaryValuePEMultiHeadAttention</span>
|
||||
<span class="lineno">28</span> <span class="k">return</span> <span class="n">RotaryValuePEMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)</span></pre></div>
|
||||
<span class="lineno">28</span> <span class="k">return</span> <span class="n">RotaryValuePEMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
@ -153,7 +153,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">39</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"rotary_pe_transformer"</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s2">"rotary_value 1.0, 0.5"</span><span class="p">,</span> <span class="n">writers</span><span class="o">=</span><span class="p">{</span><span class="s1">'screen'</span><span class="p">,</span> <span class="s1">'labml'</span><span class="p">})</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">39</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"rotary_shakespeare"</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s2">"rotary value"</span><span class="p">,</span> <span class="n">writers</span><span class="o">=</span><span class="p">{</span><span class="s1">'screen'</span><span class="p">,</span> <span class="s1">'labml'</span><span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
@ -286,7 +286,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">65</span> <span class="s1">'seq_len'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">65</span> <span class="s1">'seq_len'</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
@ -298,7 +298,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">67</span> <span class="s1">'epochs'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">67</span> <span class="s1">'epochs'</span><span class="p">:</span> <span class="mi">24</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
@ -310,7 +310,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">69</span> <span class="s1">'batch_size'</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">69</span> <span class="s1">'batch_size'</span><span class="p">:</span> <span class="mi">16</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-20'>
|
||||
@ -322,7 +322,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">72</span> <span class="s1">'inner_iterations'</span><span class="p">:</span> <span class="mi">10</span><span class="p">,</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">72</span> <span class="s1">'inner_iterations'</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-21'>
|
||||
@ -334,9 +334,9 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">75</span> <span class="s1">'d_model'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span>
|
||||
<span class="lineno">76</span> <span class="s1">'transformer.ffn.d_ff'</span><span class="p">:</span> <span class="mi">1024</span><span class="p">,</span>
|
||||
<span class="lineno">77</span> <span class="s1">'transformer.n_heads'</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span>
|
||||
<div class="highlight"><pre><span class="lineno">75</span> <span class="s1">'d_model'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
|
||||
<span class="lineno">76</span> <span class="s1">'transformer.ffn.d_ff'</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span>
|
||||
<span class="lineno">77</span> <span class="s1">'transformer.n_heads'</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span>
|
||||
<span class="lineno">78</span> <span class="s1">'transformer.dropout'</span><span class="p">:</span> <span class="mf">0.0</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
@ -345,12 +345,12 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p>Use <a href="../../optimizers/noam.html">Noam optimizer</a> </p>
|
||||
<p>Use <a href="../../optimizers/noam.html">Adam optimizer</a> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">81</span> <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Noam'</span><span class="p">,</span>
|
||||
<span class="lineno">82</span> <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">1.</span><span class="p">,</span>
|
||||
<div class="highlight"><pre><span class="lineno">81</span> <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Adam'</span><span class="p">,</span>
|
||||
<span class="lineno">82</span> <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">2.5e-4</span><span class="p">,</span>
|
||||
<span class="lineno">83</span>
|
||||
<span class="lineno">84</span> <span class="s1">'dataloader_shuffle_with_replacement'</span><span class="p">:</span> <span class="kc">True</span>
|
||||
<span class="lineno">85</span> <span class="p">})</span></pre></div>
|
||||
|
@ -97,8 +97,7 @@
|
||||
<span class="lineno">119</span>
|
||||
<span class="lineno">120</span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">121</span>
|
||||
<span class="lineno">122</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.mha</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span>
|
||||
<span class="lineno">123</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.rope</span> <span class="kn">import</span> <span class="n">RotaryPositionalEmbeddings</span></pre></div>
|
||||
<span class="lineno">122</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.rope</span> <span class="kn">import</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">,</span> <span class="n">RotaryPEMultiHeadAttention</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
@ -111,7 +110,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">126</span><span class="k">class</span> <span class="nc">ReverseRotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">RotaryPositionalEmbeddings</span><span class="p">):</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">125</span><span class="k">class</span> <span class="nc">ReverseRotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">RotaryPositionalEmbeddings</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
@ -125,7 +124,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">132</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
@ -137,7 +136,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">138</span> <span class="bp">self</span><span class="o">.</span><span class="n">_build_cache</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">137</span> <span class="bp">self</span><span class="o">.</span><span class="n">_build_cache</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
@ -149,7 +148,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">x_rope</span><span class="p">,</span> <span class="n">x_pass</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">d</span><span class="p">],</span> <span class="n">x</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d</span><span class="p">:]</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">x_rope</span><span class="p">,</span> <span class="n">x_pass</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">d</span><span class="p">],</span> <span class="n">x</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d</span><span class="p">:]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
@ -161,7 +160,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">neg_half_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_neg_half</span><span class="p">(</span><span class="n">x_rope</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">neg_half_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_neg_half</span><span class="p">(</span><span class="n">x_rope</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
@ -174,7 +173,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">161</span> <span class="n">x_rope</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_rope</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span> <span class="o">-</span> <span class="p">(</span><span class="n">neg_half_x</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">sin_cached</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">160</span> <span class="n">x_rope</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_rope</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span> <span class="o">-</span> <span class="p">(</span><span class="n">neg_half_x</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">sin_cached</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
@ -186,7 +185,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">164</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x_rope</span><span class="p">,</span> <span class="n">x_pass</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">163</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x_rope</span><span class="p">,</span> <span class="n">x_pass</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
@ -199,7 +198,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">167</span><span class="k">class</span> <span class="nc">RotaryValuePEMultiHeadAttention</span><span class="p">(</span><span class="n">MultiHeadAttention</span><span class="p">):</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">166</span><span class="k">class</span> <span class="nc">RotaryValuePEMultiHeadAttention</span><span class="p">(</span><span class="n">RotaryPEMultiHeadAttention</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
@ -210,10 +209,10 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">174</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||||
<span class="lineno">175</span> <span class="n">rope_percentage</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">rope_value_percentage</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
|
||||
<span class="lineno">176</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">):</span>
|
||||
<span class="lineno">177</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">heads</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">173</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||||
<span class="lineno">174</span> <span class="n">rope_percentage</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">rope_value_percentage</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
|
||||
<span class="lineno">175</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">):</span>
|
||||
<span class="lineno">176</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">heads</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">rope_percentage</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
@ -225,13 +224,10 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">180</span> <span class="n">d_rope</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">*</span> <span class="n">rope_percentage</span><span class="p">)</span>
|
||||
<span class="lineno">181</span> <span class="n">d_rope_value</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">*</span> <span class="n">rope_value_percentage</span><span class="p">)</span>
|
||||
<span class="lineno">182</span>
|
||||
<span class="lineno">183</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope</span><span class="p">)</span>
|
||||
<span class="lineno">184</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope</span><span class="p">)</span>
|
||||
<span class="lineno">185</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope_value</span><span class="p">)</span>
|
||||
<span class="lineno">186</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_reverse_rotary_pe</span> <span class="o">=</span> <span class="n">ReverseRotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope_value</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">179</span> <span class="n">d_rope_value</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">*</span> <span class="n">rope_value_percentage</span><span class="p">)</span>
|
||||
<span class="lineno">180</span>
|
||||
<span class="lineno">181</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope_value</span><span class="p">)</span>
|
||||
<span class="lineno">182</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_reverse_rotary_pe</span> <span class="o">=</span> <span class="n">ReverseRotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">d_rope_value</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
@ -239,30 +235,6 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<h3>Calculate scores between queries and keys</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">188</span> <span class="k">def</span> <span class="nf">get_scores</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p>Calculate dot-product with RoPE </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">194</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'ibhd,jbhd->ijbh'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_rotary_pe</span><span class="p">(</span><span class="n">query</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_rotary_pe</span><span class="p">(</span><span class="n">key</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<p> <code class="highlight"><span></span><span class="n">query</span></code>
|
||||
, <code class="highlight"><span></span><span class="n">key</span></code>
|
||||
and <code class="highlight"><span></span><span class="n">value</span></code>
|
||||
@ -278,17 +250,17 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">196</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
||||
<span class="lineno">197</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||||
<span class="lineno">198</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||||
<span class="lineno">199</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||||
<span class="lineno">200</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">184</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
||||
<span class="lineno">185</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||||
<span class="lineno">186</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||||
<span class="lineno">187</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||||
<span class="lineno">188</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p><code class="highlight"><span></span><span class="n">query</span></code>
|
||||
, <code class="highlight"><span></span><span class="n">key</span></code>
|
||||
@ -298,16 +270,16 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">212</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="lineno">213</span>
|
||||
<span class="lineno">214</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="lineno">215</span> <span class="n">mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">200</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="lineno">201</span>
|
||||
<span class="lineno">202</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="lineno">203</span> <span class="n">mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<p>Prepare <code class="highlight"><span></span><span class="n">query</span></code>
|
||||
, <code class="highlight"><span></span><span class="n">key</span></code>
|
||||
@ -317,28 +289,28 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">query</span><span class="p">)</span>
|
||||
<span class="lineno">220</span> <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
|
||||
<span class="lineno">221</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">value</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">207</span> <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">query</span><span class="p">)</span>
|
||||
<span class="lineno">208</span> <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
|
||||
<span class="lineno">209</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">value</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>Compute attention scores <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.043548em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqw" style=""><span class="mord mathnormal" style="">Q</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.849108em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">⊤</span></span></span></span></span></span></span></span></span></span></span></span>. This gives a tensor of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">heads</span><span class="p">]</span></code>
|
||||
. </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">225</span> <span class="n">scores</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_scores</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">213</span> <span class="n">scores</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_scores</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<p>Scale scores <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.633028em;vertical-align:-0.538em;"></span><span class="mord coloredeq eqq" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.095028em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord sqrt mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em"><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
|
||||
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
|
||||
@ -355,26 +327,26 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">scores</span> <span class="o">*=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">216</span> <span class="n">scores</span> <span class="o">*=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p>Apply mask </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">231</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="lineno">232</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span><span class="o">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">'-inf'</span><span class="p">))</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">219</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="lineno">220</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span><span class="o">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">'-inf'</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqy" style=""><span class="mord mathnormal" style="">so</span><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="mord mathnormal" style="">t</span><span class="mord mathnormal" style="">ma</span><span class="mord mathnormal" style="">x</span></span></span></span></span> attention along the key sequence dimension <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:3.0000299999999998em;vertical-align:-1.25003em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mop op-limits" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6944399999999998em;"><span style="top:-2.20556em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">se</span><span class="mord mathnormal mtight" style="margin-right:0.03588em">q</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop" style=""><span class="mord coloredeq eqy" style=""><span class="mord mathnormal" style="">so</span><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="mord mathnormal" style="">t</span><span class="mord mathnormal" style="">ma</span><span class="mord mathnormal" style="">x</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.030548em;"><span></span></span></span></span></span></span><span class="mord" style=""><span class="delimsizing size4" style=""><span style="">(</span></span></span><span class="mord coloredeq eqq" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.095028em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord sqrt mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em"><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
|
||||
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
|
||||
@ -391,7 +363,31 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">236</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">224</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>Apply dropout </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">227</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p>Rotate value embeddings before taking the weighted sum so that they contain positional information </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">230</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_rotary_pe</span><span class="p">(</span><span class="n">value</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-20'>
|
||||
@ -399,30 +395,6 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<p>Apply dropout </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">239</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-21'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
<p>Rotate value embeddings before taking the weighted sum so that they contain positional information </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">242</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_rotary_pe</span><span class="p">(</span><span class="n">value</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-22'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p>Multiply by values <span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:3.0000299999999998em;vertical-align:-1.25003em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mop op-limits" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6944399999999998em;"><span style="top:-2.20556em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">se</span><span class="mord mathnormal mtight" style="margin-right:0.03588em">q</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop" style=""><span class="mord coloredeq eqy" style=""><span class="mord mathnormal" style="">so</span><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="mord mathnormal" style="">t</span><span class="mord mathnormal" style="">ma</span><span class="mord mathnormal" style="">x</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.030548em;"><span></span></span></span></span></span></span><span class="mord" style=""><span class="delimsizing size4" style=""><span style="">(</span></span></span><span class="mord coloredeq eqq" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.5261079999999998em;"><span style="top:-2.25278em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord sqrt" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.85722em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em"><span class="mord" style=""><span class="mord mathnormal" style="">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.81722em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
|
||||
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
|
||||
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
|
||||
@ -438,7 +410,31 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">246</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"ijbh,jbhd->ibhd"</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_rotary_pe</span><span class="p">(</span><span class="n">value</span><span class="p">))</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">234</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"ijbh,jbhd->ibhd"</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_rotary_pe</span><span class="p">(</span><span class="n">value</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-21'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
<p>Rotate in the opposite direction so that each embedding hold the relative positions </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">237</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_reverse_rotary_pe</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-22'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p>Save attentions for any other calculations </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">240</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-23'>
|
||||
@ -446,11 +442,11 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
|
||||
<div class='section-link'>
|
||||
<a href='#section-23'>#</a>
|
||||
</div>
|
||||
<p>Rotate in the opposite direction so that each embedding hold the relative positions </p>
|
||||
<p>Concatenate multiple heads </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">249</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_reverse_rotary_pe</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">243</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-24'>
|
||||
@ -458,35 +454,11 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
|
||||
<div class='section-link'>
|
||||
<a href='#section-24'>#</a>
|
||||
</div>
|
||||
<p>Save attentions for any other calculations </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">252</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-25'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-25'>#</a>
|
||||
</div>
|
||||
<p>Concatenate multiple heads </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">255</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-26'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-26'>#</a>
|
||||
</div>
|
||||
<p>Output layer </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">258</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">246</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
|
Reference in New Issue
Block a user