sampling links

This commit is contained in:
Varuna Jayasiri
2022-08-08 12:27:11 +05:30
parent f3189e2331
commit 4cf1d74e6d
12 changed files with 84 additions and 55 deletions

View File

@ -149,6 +149,11 @@
<ul><li><a href="uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li></ul>
<h4><a href="activations/index.html">Activations</a></h4>
<ul><li><a href="activations/fta/index.html">Fuzzy Tiling Activations</a></li></ul>
<h4><a href="sampling/index.html">Sampling Techniques</a></h4>
<ul><li><a href="sampling/greedy.html">Greedy Sampling</a> </li>
<li><a href="sampling/temperature.html">Temperature Sampling</a> </li>
<li><a href="sampling/top_k.html">Top-k Sampling</a> </li>
<li><a href="sampling/nucleus.html">Nucleus Sampling</a></li></ul>
<h2>Highlighted Research Paper PDFs</h2>
<ul><li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf">Autoregressive Search Engines: Generating Substrings as Document Identifiers</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf">Training Compute-Optimal Large Language Models</a> </li>

View File

@ -76,12 +76,13 @@
</div>
<h1>Greedy Sampling</h1>
<p>Here we sample the most likely token from the distribution of logits.</p>
<p>Here&#x27;s an <a href="experiment.html">experiment</a> that uses these sampling techniques.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">12</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">13</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
<div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">15</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -92,7 +93,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">17</span><span class="k">class</span> <span class="nc">GreedySampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">19</span><span class="k">class</span> <span class="nc">GreedySampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -104,7 +105,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">18</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">20</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -115,7 +116,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">22</span> <span class="k">return</span> <span class="n">logits</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">24</span> <span class="k">return</span> <span class="n">logits</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>

View File

@ -81,13 +81,14 @@
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:2.541535em;vertical-align:-1.49153em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.050005em;"><span style="top:-1.75857em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mrel mtight"></span><span class="mord mtight coloredeq eqd" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.22222em">V</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8220357142857143em;"><span style="top:-2.8220357142857138em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5357142857142856em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mopen mtight" style="">(</span><span class="mord mtight coloredeq eqe" style=""><span class="mord mathnormal mtight" style="">p</span></span><span class="mclose mtight" style="">)</span></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.0500049999999996em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op"></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.49153em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqb" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord" style=""></span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">1</span><span class="mrel mtight" style="">:</span><span class="mord mathnormal mtight" style="">i</span><span class="mbin mtight" style=""></span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqe" style=""><span class="mord mathnormal" style="">p</span></span></span></span></span></span></p>
<p>That is, we pick the highest probable tokens until the sum of their probabilities is less that <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqe" style=""><span class="mord mathnormal" style="">p</span></span></span></span></span>.</p>
<p>Then we sample from the selected tokens.</p>
<p>Here&#x27;s an <a href="experiment.html">experiment</a> that uses these sampling techniques.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">29</span>
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
<div class="highlight"><pre><span class="lineno">29</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">31</span>
<span class="lineno">32</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -99,7 +100,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span><span class="k">class</span> <span class="nc">NucleusSampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">35</span><span class="k">class</span> <span class="nc">NucleusSampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -114,7 +115,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">39</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -125,8 +126,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span> <span class="o">=</span> <span class="n">p</span>
<span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">sampler</span></pre></div>
<div class="highlight"><pre><span class="lineno">44</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span> <span class="o">=</span> <span class="n">p</span>
<span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">sampler</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -138,7 +139,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">47</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -150,7 +151,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">49</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -162,7 +163,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="n">probs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">55</span> <span class="n">probs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -174,7 +175,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">sorted_probs</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">58</span> <span class="n">sorted_probs</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -186,7 +187,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="n">cum_sum_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">cum_sum_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -198,7 +199,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">nucleus</span> <span class="o">=</span> <span class="n">cum_sum_probs</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span></pre></div>
<div class="highlight"><pre><span class="lineno">62</span> <span class="n">nucleus</span> <span class="o">=</span> <span class="n">cum_sum_probs</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -210,7 +211,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="n">nucleus</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">nucleus</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">nucleus</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span><span class="p">,)),</span> <span class="n">nucleus</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">nucleus</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">nucleus</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">nucleus</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span><span class="p">,)),</span> <span class="n">nucleus</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -222,8 +223,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="n">sorted_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">)</span>
<span class="lineno">67</span> <span class="n">sorted_log_probs</span><span class="p">[</span><span class="o">~</span><span class="n">nucleus</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">sorted_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">)</span>
<span class="lineno">69</span> <span class="n">sorted_log_probs</span><span class="p">[</span><span class="o">~</span><span class="n">nucleus</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -235,7 +236,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">sampled_sorted_indexes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="p">(</span><span class="n">sorted_log_probs</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">sampled_sorted_indexes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="p">(</span><span class="n">sorted_log_probs</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -247,7 +248,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">res</span> <span class="o">=</span> <span class="n">indices</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">sampled_sorted_indexes</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">res</span> <span class="o">=</span> <span class="n">indices</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">sampled_sorted_indexes</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
@ -259,7 +260,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">return</span> <span class="n">res</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">return</span> <span class="n">res</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>

File diff suppressed because one or more lines are too long

View File

@ -76,12 +76,13 @@
</div>
<h1>Top-k Sampling</h1>
<p>Here we first pick the top-k tokens from the distribution of logits, and then sample from them.</p>
<p>Here&#x27;s an <a href="experiment.html">experiment</a> that uses these sampling techniques.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">13</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">14</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
<div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">16</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -93,7 +94,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">18</span><span class="k">class</span> <span class="nc">TopKSampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">20</span><span class="k">class</span> <span class="nc">TopKSampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -110,7 +111,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">22</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">24</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -121,8 +122,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">30</span> <span class="bp">self</span><span class="o">.</span><span class="n">k</span> <span class="o">=</span> <span class="n">k</span>
<span class="lineno">31</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">sampler</span></pre></div>
<div class="highlight"><pre><span class="lineno">32</span> <span class="bp">self</span><span class="o">.</span><span class="n">k</span> <span class="o">=</span> <span class="n">k</span>
<span class="lineno">33</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">sampler</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -134,7 +135,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">35</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -146,7 +147,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">38</span> <span class="n">zeros</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">logits</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">40</span> <span class="n">zeros</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">logits</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -158,7 +159,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span> <span class="n">values</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">k</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">42</span> <span class="n">values</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">k</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -170,7 +171,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">43</span> <span class="n">zeros</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">values</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">45</span> <span class="n">zeros</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">values</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -182,7 +183,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">46</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="p">(</span><span class="n">zeros</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">48</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="p">(</span><span class="n">zeros</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>

View File

@ -316,49 +316,49 @@
<url>
<loc>https://nn.labml.ai/sampling/experiment_tiny.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod>
<lastmod>2022-08-08T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/sampling/greedy.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod>
<lastmod>2022-08-08T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/sampling/index.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod>
<lastmod>2022-08-08T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/sampling/top_k.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod>
<lastmod>2022-08-08T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/sampling/temperature.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod>
<lastmod>2022-08-08T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/sampling/experiment.html</loc>
<lastmod>2022-05-07T16:30:00+00:00</lastmod>
<lastmod>2022-08-08T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/sampling/nucleus.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod>
<lastmod>2022-08-08T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>

View File

@ -119,6 +119,12 @@ Solving games with incomplete information such as poker with CFR.
* [Fuzzy Tiling Activations](activations/fta/index.html)
#### ✨ [Sampling Techniques](sampling/index.html)
* [Greedy Sampling](sampling/greedy.html)
* [Temperature Sampling](sampling/temperature.html)
* [Top-k Sampling](sampling/top_k.html)
* [Nucleus Sampling](sampling/nucleus.html)
## Highlighted Research Paper PDFs
* [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf)

View File

@ -7,6 +7,8 @@ summary: A PyTorch implementation of greedy sampling from language models.
# Greedy Sampling
Here we sample the most likely token from the distribution of logits.
Here's an [experiment](experiment.html) that uses these sampling techniques.
"""
import torch

View File

@ -22,6 +22,8 @@ $$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \ge p$$
That is, we pick the highest probable tokens until the sum of their probabilities is less that $p$.
Then we sample from the selected tokens.
Here's an [experiment](experiment.html) that uses these sampling techniques.
"""
import torch

View File

@ -12,6 +12,8 @@ $u_{1:|V|}$ are the logits of the distribution and T is the temperature:
$$P(x_i=V_l | x_{1:i-1}) = \frac{\exp(\frac{u_l}{T})}{\sum_j \exp(\frac{u_j}{T})}$$
$T = 1$ is normal random sampling.
Here's an [experiment](experiment.html) that uses these sampling techniques.
"""
import torch

View File

@ -8,6 +8,8 @@ summary: A PyTorch implementation of top-k sampling from language models.
Here we first pick the top-k tokens from the distribution of logits, and then
sample from them.
Here's an [experiment](experiment.html) that uses these sampling techniques.
"""
import torch

View File

@ -123,6 +123,12 @@ Solving games with incomplete information such as poker with CFR.
* [Fuzzy Tiling Activations](https://nn.labml.ai/activations/fta/index.html)
#### ✨ [Sampling Techniques](https://nn.labml.ai/sampling/index.html)
* [Greedy Sampling](https://nn.labml.ai/sampling/greedy.html)
* [Temperature Sampling](https://nn.labml.ai/sampling/temperature.html)
* [Top-k Sampling](https://nn.labml.ai/sampling/top_k.html)
* [Nucleus Sampling](https://nn.labml.ai/sampling/nucleus.html)
## Highlighted Research Paper PDFs