sampling links

This commit is contained in:
Varuna Jayasiri
2022-08-08 12:27:11 +05:30
parent f3189e2331
commit 4cf1d74e6d
12 changed files with 84 additions and 55 deletions

View File

@ -81,13 +81,14 @@
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:2.541535em;vertical-align:-1.49153em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.050005em;"><span style="top:-1.75857em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mrel mtight"></span><span class="mord mtight coloredeq eqd" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.22222em">V</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8220357142857143em;"><span style="top:-2.8220357142857138em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5357142857142856em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mopen mtight" style="">(</span><span class="mord mtight coloredeq eqe" style=""><span class="mord mathnormal mtight" style="">p</span></span><span class="mclose mtight" style="">)</span></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.0500049999999996em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op"></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.49153em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqb" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord" style=""></span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">1</span><span class="mrel mtight" style="">:</span><span class="mord mathnormal mtight" style="">i</span><span class="mbin mtight" style=""></span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqe" style=""><span class="mord mathnormal" style="">p</span></span></span></span></span></span></p>
<p>That is, we pick the highest probable tokens until the sum of their probabilities is less that <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqe" style=""><span class="mord mathnormal" style="">p</span></span></span></span></span>.</p>
<p>Then we sample from the selected tokens.</p>
<p>Here&#x27;s an <a href="experiment.html">experiment</a> that uses these sampling techniques.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">29</span>
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
<div class="highlight"><pre><span class="lineno">29</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">31</span>
<span class="lineno">32</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -99,7 +100,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span><span class="k">class</span> <span class="nc">NucleusSampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">35</span><span class="k">class</span> <span class="nc">NucleusSampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -114,7 +115,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">39</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -125,8 +126,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span> <span class="o">=</span> <span class="n">p</span>
<span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">sampler</span></pre></div>
<div class="highlight"><pre><span class="lineno">44</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span> <span class="o">=</span> <span class="n">p</span>
<span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">sampler</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -138,7 +139,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">47</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -150,7 +151,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">49</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -162,7 +163,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="n">probs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">55</span> <span class="n">probs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -174,7 +175,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">sorted_probs</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">58</span> <span class="n">sorted_probs</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -186,7 +187,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="n">cum_sum_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">cum_sum_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -198,7 +199,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">nucleus</span> <span class="o">=</span> <span class="n">cum_sum_probs</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span></pre></div>
<div class="highlight"><pre><span class="lineno">62</span> <span class="n">nucleus</span> <span class="o">=</span> <span class="n">cum_sum_probs</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -210,7 +211,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="n">nucleus</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">nucleus</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">nucleus</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span><span class="p">,)),</span> <span class="n">nucleus</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">nucleus</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">nucleus</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">nucleus</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span><span class="p">,)),</span> <span class="n">nucleus</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -222,8 +223,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="n">sorted_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">)</span>
<span class="lineno">67</span> <span class="n">sorted_log_probs</span><span class="p">[</span><span class="o">~</span><span class="n">nucleus</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">sorted_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">)</span>
<span class="lineno">69</span> <span class="n">sorted_log_probs</span><span class="p">[</span><span class="o">~</span><span class="n">nucleus</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -235,7 +236,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">sampled_sorted_indexes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="p">(</span><span class="n">sorted_log_probs</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">sampled_sorted_indexes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="p">(</span><span class="n">sorted_log_probs</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -247,7 +248,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">res</span> <span class="o">=</span> <span class="n">indices</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">sampled_sorted_indexes</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">res</span> <span class="o">=</span> <span class="n">indices</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">sampled_sorted_indexes</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
@ -259,7 +260,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">return</span> <span class="n">res</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">return</span> <span class="n">res</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>