Sampling Techniques (#139)

This commit is contained in:
Varuna Jayasiri
2022-08-08 11:12:32 +05:30
committed by GitHub
parent 940b3c01fc
commit f3189e2331
24 changed files with 2358 additions and 37 deletions

View File

@ -253,10 +253,14 @@
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<ul><li><code>n_tokens</code> is the number of tokens in the vocabulary </li>
<li><code>d_model</code> is the embedding size </li>
<li><code>n_layers</code> is the number of transformer layers </li>
<li><code>layer</code> is the layer. We use <code class="highlight"><span></span><span class="n">n_layers</span></code>
<ul><li><code class="highlight"><span></span><span class="n">n_tokens</span></code>
is the number of tokens in the vocabulary </li>
<li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the embedding size </li>
<li><code class="highlight"><span></span><span class="n">n_layers</span></code>
is the number of transformer layers </li>
<li><code class="highlight"><span></span><span class="n">layer</span></code>
is the layer. We use <code class="highlight"><span></span><span class="n">n_layers</span></code>
copies of this for the transformer.</li></ul>
</div>
@ -329,7 +333,8 @@
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<ul><li><code>x</code> are the input tokens of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">]</span></code>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
are the input tokens of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">]</span></code>
</li></ul>
</div>

View File

@ -120,10 +120,14 @@
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code>lower_limit</code> is the lower limit <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqq" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span></span></span></span></span> </li>
<li><code>upper_limit</code> is the upper limit <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord mathnormal" style="">u</span></span></span></span></span> </li>
<li><code>delta</code> is the bin size <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span></span></span></span></span> </li>
<li><code>eta</code> is the parameter <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.03588em">η</span></span></span></span></span> that detemines the softness of the boundaries.</li></ul>
<ul><li><code class="highlight"><span></span><span class="n">lower_limit</span></code>
is the lower limit <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqq" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">upper_limit</span></code>
is the upper limit <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord mathnormal" style="">u</span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">delta</span></code>
is the bin size <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">eta</span></code>
is the parameter <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.03588em">η</span></span></span></span></span> that detemines the softness of the boundaries.</li></ul>
</div>
<div class='code'>

View File

@ -110,7 +110,8 @@
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code>seq_len</code> is the sequence length of generated math problems. We fill as many problems as possible upto this length :max_digits: is the maximum number of digits in the operand integers :n_sequences: is the number of sequences per epoch</li></ul>
<ul><li><code class="highlight"><span></span><span class="n">seq_len</span></code>
is the sequence length of generated math problems. We fill as many problems as possible upto this length :max_digits: is the maximum number of digits in the operand integers :n_sequences: is the number of sequences per epoch</li></ul>
</div>
<div class='code'>

View File

@ -112,10 +112,14 @@
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code>n_tokens</code> is the number of tokens in the vocabulary </li>
<li><code>d_model</code> is the embedding size </li>
<li><code>n_layers</code> is the number of transformer layers </li>
<li><code>layer</code> is the layer. We use <code class="highlight"><span></span><span class="n">n_layers</span></code>
<ul><li><code class="highlight"><span></span><span class="n">n_tokens</span></code>
is the number of tokens in the vocabulary </li>
<li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the embedding size </li>
<li><code class="highlight"><span></span><span class="n">n_layers</span></code>
is the number of transformer layers </li>
<li><code class="highlight"><span></span><span class="n">layer</span></code>
is the layer. We use <code class="highlight"><span></span><span class="n">n_layers</span></code>
copies of this for the tranformer.</li></ul>
</div>
@ -176,7 +180,8 @@
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<ul><li><code>x</code> are the input tokens of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">]</span></code>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
are the input tokens of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">]</span></code>
</li></ul>
</div>

View File

@ -124,10 +124,14 @@
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code>alpha</code> is <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqf" style=""><span class="mord mathnormal" style="margin-right:0.0037em">α</span></span></span></span></span> </li>
<li><code>normalized_shape</code> is the shape for LayerNorm <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqd" style=""><span class="mop" style=""><span class="mord mathnormal" style="">L</span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span></span></span></span></span></span></span> </li>
<li><code>eps</code> is <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal">ϵ</span></span></span></span> for LayerNorm </li>
<li><code>elementwise_affine</code> is a flag indicating whether to do an elementwise transformation in LayerNorm</li></ul>
<ul><li><code class="highlight"><span></span><span class="n">alpha</span></code>
is <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqf" style=""><span class="mord mathnormal" style="margin-right:0.0037em">α</span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">normalized_shape</span></code>
is the shape for LayerNorm <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqd" style=""><span class="mop" style=""><span class="mord mathnormal" style="">L</span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span></span></span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">eps</span></code>
is <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal">ϵ</span></span></span></span> for LayerNorm </li>
<li><code class="highlight"><span></span><span class="n">elementwise_affine</span></code>
is a flag indicating whether to do an elementwise transformation in LayerNorm</li></ul>
</div>
<div class='code'>
@ -166,8 +170,10 @@
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<ul><li><code>x</code> is the output from the previous layer <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> </li>
<li><code>gx</code> is the output of the current sub-layer <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop"><span class="mop mathnormal" style="position:relative;top:0.091665em;">G</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></li></ul>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
is the output from the previous layer <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">gx</span></code>
is the output of the current sub-layer <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop"><span class="mop mathnormal" style="position:relative;top:0.091665em;">G</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></li></ul>
</div>
<div class='code'>
@ -204,11 +210,16 @@
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<ul><li><code>d_model</code> is the token embedding size </li>
<li><code>self_attn</code> is the self attention module </li>
<li><code>feed_forward</code> is the feed forward module </li>
<li><code>deep_norm_alpha</code> is <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqf" style=""><span class="mord mathnormal" style="margin-right:0.0037em">α</span></span></span></span></span> coefficient in DeepNorm </li>
<li><code>deep_norm_beta</code> is <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqg" style=""><span class="mord mathnormal" style="margin-right:0.05278em">β</span></span></span></span></span> constant for scaling weights initialization</li></ul>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the token embedding size </li>
<li><code class="highlight"><span></span><span class="n">self_attn</span></code>
is the self attention module </li>
<li><code class="highlight"><span></span><span class="n">feed_forward</span></code>
is the feed forward module </li>
<li><code class="highlight"><span></span><span class="n">deep_norm_alpha</span></code>
is <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqf" style=""><span class="mord mathnormal" style="margin-right:0.0037em">α</span></span></span></span></span> coefficient in DeepNorm </li>
<li><code class="highlight"><span></span><span class="n">deep_norm_beta</span></code>
is <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqg" style=""><span class="mord mathnormal" style="margin-right:0.05278em">β</span></span></span></span></span> constant for scaling weights initialization</li></ul>
</div>
<div class='code'>
@ -314,7 +325,8 @@
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<ul><li><code>x</code> are the embeddings of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
are the embeddings of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
</li></ul>
</div>

View File

@ -93,6 +93,9 @@
"1904.09237": [
"https://nn.labml.ai/optimizers/amsgrad.html"
],
"1904.09751": [
"https://nn.labml.ai/sampling/nucleus.html"
],
"1908.03265": [
"https://nn.labml.ai/optimizers/radam.html"
],

View File

@ -148,6 +148,10 @@ div.section div.code pre {
white-space: pre-wrap;
}
.highlight .n, .highlight .nn, .highlight .nc, .highlight .nf {
cursor: pointer;
}
code {
padding: 0.2rem 0.5rem;
margin: 0 0.2rem;

View File

@ -0,0 +1,423 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="We try out different sampling techniques for language models on HuggingFace&#x27;s GPT2 model."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Trying out Sampling Techniques for Language Models"/>
<meta name="twitter:description" content="We try out different sampling techniques for language models on HuggingFace&#x27;s GPT2 model."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/sampling/experiment.html"/>
<meta property="og:title" content="Trying out Sampling Techniques for Language Models"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Trying out Sampling Techniques for Language Models"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Trying out Sampling Techniques for Language Models"/>
<meta property="og:description" content="We try out different sampling techniques for language models on HuggingFace&#x27;s GPT2 model."/>
<title>Trying out Sampling Techniques for Language Models</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/sampling/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">sampling</a>
</p>
<p>
<a href="https://github.com/sponsors/labmlai" target="_blank">
<img alt="Sponsor"
src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86"
style="max-width:100%;"/></a>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/sampling/experiment.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Trying out Sampling Techniques for Language Models</h1>
<ul><li><a href="greedy.html">Greedy Sampling</a> </li>
<li><a href="temperature.html">Temperature Sampling</a> </li>
<li><a href="top_k.html">Top-k Sampling</a> </li>
<li><a href="nucleus.html">Nucleus Sampling</a></li></ul>
<p>This experiment uses the above sampling techniques, on HuggingFace&#x27;s GPT2 model.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">18</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">19</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span><span class="p">,</span> <span class="n">logger</span><span class="p">,</span> <span class="n">lab</span>
<span class="lineno">21</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">Text</span>
<span class="lineno">23</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span>
<span class="lineno">25</span><span class="kn">from</span> <span class="nn">labml_nn.sampling.greedy</span> <span class="kn">import</span> <span class="n">GreedySampler</span>
<span class="lineno">26</span><span class="kn">from</span> <span class="nn">labml_nn.sampling.nucleus</span> <span class="kn">import</span> <span class="n">NucleusSampler</span>
<span class="lineno">27</span><span class="kn">from</span> <span class="nn">labml_nn.sampling.temperature</span> <span class="kn">import</span> <span class="n">TemperatureSampler</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">labml_nn.sampling.top_k</span> <span class="kn">import</span> <span class="n">TopKSampler</span>
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">GPT2Tokenizer</span><span class="p">,</span> <span class="n">GPT2LMHeadModel</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Sample from model</h2>
<ul><li><code class="highlight"><span></span><span class="n">model</span></code>
is the model to sample from </li>
<li><code class="highlight"><span></span><span class="n">tokenizer</span></code>
is the tokenizer to use </li>
<li><code class="highlight"><span></span><span class="n">sampler</span></code>
is the sampler to use </li>
<li><code class="highlight"><span></span><span class="n">n_samples</span></code>
is the number of samples to generate </li>
<li><code class="highlight"><span></span><span class="n">n_tokens</span></code>
is the number of tokens to generate </li>
<li><code class="highlight"><span></span><span class="n">seq_len</span></code>
is the maximum sequence length for the model </li>
<li><code class="highlight"><span></span><span class="n">prompt</span></code>
is the starting prompt</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">33</span><span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">GPT2LMHeadModel</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">:</span> <span class="n">GPT2Tokenizer</span><span class="p">,</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span><span class="p">,</span>
<span class="lineno">34</span> <span class="n">n_samples</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">prompt</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>Tokenize the <code class="highlight"><span></span><span class="n">prompt</span></code>
and make <code class="highlight"><span></span><span class="n">n_samples</span></code>
copies of it </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">prompt</span><span class="p">))[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:],</span> <span class="p">(</span><span class="n">n_samples</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>Collect output for printing </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">50</span> <span class="n">logs</span> <span class="o">=</span> <span class="p">[[(</span><span class="n">prompt</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">meta</span><span class="p">)]</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Sample <code class="highlight"><span></span><span class="n">n_tokens</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">52</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">&#39;Sample&#39;</span><span class="p">,</span> <span class="n">n_tokens</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Truncate the data to the maximum sequence length </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="o">-</span><span class="n">seq_len</span><span class="p">:]</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Get the model output. The &#x27;logits&#x27; has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_tokens</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Get the <code class="highlight"><span></span><span class="n">logits</span></code>
of the last token </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Sample from the <code class="highlight"><span></span><span class="n">logits</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">res</span> <span class="o">=</span> <span class="n">sampler</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Add the sampled token to the data </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">data</span><span class="p">,</span> <span class="n">res</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Decode and add the sampled token for logging </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">):</span>
<span class="lineno">65</span> <span class="n">logs</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="p">[(</span><span class="s1">&#39;&#39;</span> <span class="o">+</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">res</span><span class="p">[</span><span class="n">j</span><span class="p">]),</span> <span class="n">Text</span><span class="o">.</span><span class="n">value</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Print the sampled outputs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">):</span>
<span class="lineno">69</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">logs</span><span class="p">[</span><span class="n">j</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<h3>Try different sampling techniques</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Load the model and tokenizer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Load tokenizer/model&#39;</span><span class="p">):</span>
<span class="lineno">79</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">GPT2Tokenizer</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s1">&#39;gpt2&#39;</span><span class="p">,</span> <span class="n">cache_dir</span><span class="o">=</span><span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;cache&#39;</span><span class="p">)</span>
<span class="lineno">80</span> <span class="n">model</span> <span class="o">=</span> <span class="n">GPT2LMHeadModel</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s1">&#39;gpt2&#39;</span><span class="p">,</span> <span class="n">cache_dir</span><span class="o">=</span><span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;cache&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Set the model to eval mode </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Prompts to use for sampling </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span> <span class="n">prompt</span> <span class="o">=</span> <span class="s1">&#39;I saw an interesting dream last night. &#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p><a href="greedy.html">Greedy Sampling</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;greedy&#39;</span><span class="p">):</span>
<span class="lineno">89</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">,</span> <span class="n">GreedySampler</span><span class="p">(),</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="n">prompt</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p><a href="temperature.html">Temperature Sampling</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;temperature=1.&#39;</span><span class="p">):</span>
<span class="lineno">93</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">,</span> <span class="n">TemperatureSampler</span><span class="p">(</span><span class="mf">1.</span><span class="p">),</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="n">prompt</span><span class="p">)</span>
<span class="lineno">94</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;temperature=.1&#39;</span><span class="p">):</span>
<span class="lineno">95</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">,</span> <span class="n">TemperatureSampler</span><span class="p">(</span><span class="mf">.1</span><span class="p">),</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="n">prompt</span><span class="p">)</span>
<span class="lineno">96</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;temperature=10.&#39;</span><span class="p">):</span>
<span class="lineno">97</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">,</span> <span class="n">TemperatureSampler</span><span class="p">(</span><span class="mf">10.</span><span class="p">),</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="n">prompt</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p><a href="top_k.html">Top-k Sampling</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;top_k=5&#39;</span><span class="p">):</span>
<span class="lineno">101</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">,</span> <span class="n">TopKSampler</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">TemperatureSampler</span><span class="p">(</span><span class="mf">1.</span><span class="p">)),</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="n">prompt</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p><a href="nucleus.html">Nucleus Sampling</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">104</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;nucleus p=.95&#39;</span><span class="p">):</span>
<span class="lineno">105</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">,</span> <span class="n">NucleusSampler</span><span class="p">(</span><span class="mf">0.95</span><span class="p">,</span> <span class="n">TemperatureSampler</span><span class="p">(</span><span class="mf">1.</span><span class="p">)),</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="n">prompt</span><span class="p">)</span>
<span class="lineno">106</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;nucleus p=.1&#39;</span><span class="p">):</span>
<span class="lineno">107</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">,</span> <span class="n">NucleusSampler</span><span class="p">(</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">TemperatureSampler</span><span class="p">(</span><span class="mf">1.</span><span class="p">)),</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="n">prompt</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">111</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src=../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

View File

@ -0,0 +1,302 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="experiment_tiny.py"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/sampling/experiment_tiny.html"/>
<meta property="og:title" content="experiment_tiny.py"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="experiment_tiny.py"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="experiment_tiny.py"/>
<meta property="og:description" content=""/>
<title>experiment_tiny.py</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/sampling/experiment_tiny.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">sampling</a>
</p>
<p>
<a href="https://github.com/sponsors/labmlai" target="_blank">
<img alt="Sponsor"
src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86"
style="max-width:100%;"/></a>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/sampling/experiment_tiny.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">1</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span>
<span class="lineno">2</span>
<span class="lineno">3</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">4</span>
<span class="lineno">5</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">monit</span>
<span class="lineno">6</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">logger</span>
<span class="lineno">7</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">Text</span>
<span class="lineno">8</span><span class="kn">from</span> <span class="nn">labml_helpers.datasets.text</span> <span class="kn">import</span> <span class="n">TextDataset</span>
<span class="lineno">9</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span>
<span class="lineno">10</span><span class="kn">from</span> <span class="nn">labml_nn.sampling.greedy</span> <span class="kn">import</span> <span class="n">GreedySampler</span>
<span class="lineno">11</span><span class="kn">from</span> <span class="nn">labml_nn.sampling.nucleus</span> <span class="kn">import</span> <span class="n">NucleusSampler</span>
<span class="lineno">12</span><span class="kn">from</span> <span class="nn">labml_nn.sampling.temperature</span> <span class="kn">import</span> <span class="n">TemperatureSampler</span>
<span class="lineno">13</span><span class="kn">from</span> <span class="nn">labml_nn.sampling.top_k</span> <span class="kn">import</span> <span class="n">TopKSampler</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.basic.autoregressive_experiment</span> <span class="kn">import</span> <span class="n">Configs</span><span class="p">,</span> <span class="n">AutoregressiveTransformer</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">17</span><span class="k">def</span> <span class="nf">get_model_dataset</span><span class="p">(</span><span class="n">run_uuid</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">AutoregressiveTransformer</span><span class="p">,</span> <span class="n">TextDataset</span><span class="p">]:</span>
<span class="lineno">18</span> <span class="n">experiment</span><span class="o">.</span><span class="n">evaluate</span><span class="p">()</span>
<span class="lineno">19</span>
<span class="lineno">20</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span>
<span class="lineno">21</span>
<span class="lineno">22</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="n">experiment</span><span class="o">.</span><span class="n">load_configs</span><span class="p">(</span><span class="n">run_uuid</span><span class="p">))</span>
<span class="lineno">23</span>
<span class="lineno">24</span> <span class="n">experiment</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">run_uuid</span><span class="p">)</span>
<span class="lineno">25</span>
<span class="lineno">26</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span>
<span class="lineno">27</span>
<span class="lineno">28</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">()</span>
<span class="lineno">29</span>
<span class="lineno">30</span> <span class="k">return</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">,</span> <span class="n">conf</span><span class="o">.</span><span class="n">text</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span><span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">ds</span><span class="p">,</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span><span class="p">,</span> <span class="n">n_samples</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">prompt</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
<span class="lineno">34</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="lineno">35</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">ds</span><span class="o">.</span><span class="n">text_to_i</span><span class="p">(</span><span class="n">prompt</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">],</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_samples</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>Collect output for printing </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">38</span> <span class="n">logs</span> <span class="o">=</span> <span class="p">[[(</span><span class="n">prompt</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">meta</span><span class="p">)]</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Sample 25 tokens </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">&#39;Sample&#39;</span><span class="p">,</span> <span class="n">n_tokens</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Tokenize the prompt </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="o">-</span><span class="n">seq_len</span><span class="p">:]</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Get the model output </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span> <span class="n">logits</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="lineno">45</span> <span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Get the model prediction (greedy) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="n">res</span> <span class="o">=</span> <span class="n">sampler</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span>
<span class="lineno">48</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">data</span><span class="p">,</span> <span class="n">res</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Add the prediction for logging </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">50</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">):</span>
<span class="lineno">51</span> <span class="n">logs</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="p">[(</span><span class="s1">&#39;&#39;</span> <span class="o">+</span> <span class="n">ds</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">res</span><span class="p">[</span><span class="n">j</span><span class="p">]],</span> <span class="n">Text</span><span class="o">.</span><span class="n">value</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Print the sampled output </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">):</span>
<span class="lineno">55</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">logs</span><span class="p">[</span><span class="n">j</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span>
<span class="lineno">59</span> <span class="n">model</span><span class="p">,</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">get_model_dataset</span><span class="p">(</span><span class="s1">&#39;074d4004cc6b11ecad7a0242ac1c0002&#39;</span><span class="p">)</span>
<span class="lineno">60</span> <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
<span class="lineno">61</span>
<span class="lineno">62</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;greedy&#39;</span><span class="p">):</span>
<span class="lineno">63</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">ds</span><span class="p">,</span> <span class="n">GreedySampler</span><span class="p">(),</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;It is&#39;</span><span class="p">)</span>
<span class="lineno">64</span>
<span class="lineno">65</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;temperature=1.&#39;</span><span class="p">):</span>
<span class="lineno">66</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">ds</span><span class="p">,</span> <span class="n">TemperatureSampler</span><span class="p">(</span><span class="mf">1.</span><span class="p">),</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;It is&#39;</span><span class="p">)</span>
<span class="lineno">67</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;temperature=.1&#39;</span><span class="p">):</span>
<span class="lineno">68</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">ds</span><span class="p">,</span> <span class="n">TemperatureSampler</span><span class="p">(</span><span class="mf">.1</span><span class="p">),</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;It is&#39;</span><span class="p">)</span>
<span class="lineno">69</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;temperature=10.&#39;</span><span class="p">):</span>
<span class="lineno">70</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">ds</span><span class="p">,</span> <span class="n">TemperatureSampler</span><span class="p">(</span><span class="mf">10.</span><span class="p">),</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;It is&#39;</span><span class="p">)</span>
<span class="lineno">71</span>
<span class="lineno">72</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;top_k=5&#39;</span><span class="p">):</span>
<span class="lineno">73</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">ds</span><span class="p">,</span> <span class="n">TopKSampler</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">TemperatureSampler</span><span class="p">(</span><span class="mf">1.</span><span class="p">)),</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;It is&#39;</span><span class="p">)</span>
<span class="lineno">74</span>
<span class="lineno">75</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;nucles p=.95&#39;</span><span class="p">):</span>
<span class="lineno">76</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">ds</span><span class="p">,</span> <span class="n">NucleusSampler</span><span class="p">(</span><span class="mf">0.95</span><span class="p">,</span> <span class="n">TemperatureSampler</span><span class="p">(</span><span class="mf">1.</span><span class="p">)),</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;It is&#39;</span><span class="p">)</span>
<span class="lineno">77</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;nucles p=.95&#39;</span><span class="p">):</span>
<span class="lineno">78</span> <span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">ds</span><span class="p">,</span> <span class="n">NucleusSampler</span><span class="p">(</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">TemperatureSampler</span><span class="p">(</span><span class="mf">1.</span><span class="p">)),</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">&#39;It is&#39;</span><span class="p">)</span>
<span class="lineno">79</span>
<span class="lineno">80</span>
<span class="lineno">81</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">82</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src=../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

167
docs/sampling/greedy.html Normal file
View File

@ -0,0 +1,167 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementation of greedy sampling from language models."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Greedy Sampling"/>
<meta name="twitter:description" content="A PyTorch implementation of greedy sampling from language models."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/sampling/greedy.html"/>
<meta property="og:title" content="Greedy Sampling"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Greedy Sampling"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Greedy Sampling"/>
<meta property="og:description" content="A PyTorch implementation of greedy sampling from language models."/>
<title>Greedy Sampling</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/sampling/greedy.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">sampling</a>
</p>
<p>
<a href="https://github.com/sponsors/labmlai" target="_blank">
<img alt="Sponsor"
src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86"
style="max-width:100%;"/></a>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/sampling/greedy.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Greedy Sampling</h1>
<p>Here we sample the most likely token from the distribution of logits.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">12</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">13</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">17</span><span class="k">class</span> <span class="nc">GreedySampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p> Sample the most likely token from the distribution of logits</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">18</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">22</span> <span class="k">return</span> <span class="n">logits</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src=../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

173
docs/sampling/index.html Normal file
View File

@ -0,0 +1,173 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A set of PyTorch implementations/tutorials of sampling techniques for language models."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Sampling Techniques for Language Models"/>
<meta name="twitter:description" content="A set of PyTorch implementations/tutorials of sampling techniques for language models."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/sampling/index.html"/>
<meta property="og:title" content="Sampling Techniques for Language Models"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Sampling Techniques for Language Models"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Sampling Techniques for Language Models"/>
<meta property="og:description" content="A set of PyTorch implementations/tutorials of sampling techniques for language models."/>
<title>Sampling Techniques for Language Models</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/sampling/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">sampling</a>
</p>
<p>
<a href="https://github.com/sponsors/labmlai" target="_blank">
<img alt="Sponsor"
src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86"
style="max-width:100%;"/></a>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/sampling/__init__.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Sampling Techniques for Language Models</h1>
<ul><li><a href="greedy.html">Greedy Sampling</a> </li>
<li><a href="temperature.html">Temperature Sampling</a> </li>
<li><a href="top_k.html">Top-k Sampling</a> </li>
<li><a href="nucleus.html">Nucleus Sampling</a></li></ul>
<p>Here&#x27;s an <a href="experiment.html">experiment</a> that uses these sampling techniques.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">18</span><span></span><span class="kn">import</span> <span class="nn">torch</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h3>Sampler base class</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">21</span><span class="k">class</span> <span class="nc">Sampler</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<h3>Sample from logits</h3>
<ul><li><code class="highlight"><span></span><span class="n">logits</span></code>
are the logits of the distribution of shape <code class="highlight"><span></span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">n_tokens</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">25</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">31</span> <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src=../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

311
docs/sampling/nucleus.html Normal file
View File

@ -0,0 +1,311 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementation of nucleus sampling from language models."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Nucleus Sampling"/>
<meta name="twitter:description" content="A PyTorch implementation of nucleus sampling from language models."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/sampling/nucleus.html"/>
<meta property="og:title" content="Nucleus Sampling"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Nucleus Sampling"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Nucleus Sampling"/>
<meta property="og:description" content="A PyTorch implementation of nucleus sampling from language models."/>
<title>Nucleus Sampling</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/sampling/nucleus.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">sampling</a>
</p>
<p>
<a href="https://github.com/sponsors/labmlai" target="_blank">
<img alt="Sponsor"
src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86"
style="max-width:100%;"/></a>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/sampling/nucleus.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Nucleus Sampling</h1>
<p>This is an implementation of nucleus sampling, introduced in the paper <a href="https://papers.labml.ai/paper/1904.09751">The Curious Case of Neural Text Degeneration</a>.</p>
<p>The paper discusses the problems with other sampling methods such as Beam Search, <a href="temperature.html">Pure sampling</a>, <a href="temperature.html">Temperature sampling</a>, and <a href="top_k.html">Top-k sampling</a>. The paper introduces the idea of nucleus sampling, which practically performs better than other sampling methods for text generation.</p>
<p>Nucleus sampling first picks a subset of the vocabulary <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.9270999999999999em;vertical-align:-0.0391em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.22222em">V</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8879999999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mopen mtight" style="">(</span><span class="mord mtight coloredeq eqe" style=""><span class="mord mathnormal mtight" style="">p</span></span><span class="mclose mtight" style="">)</span></span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span></span></span></span>, where <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8879999999999999em;vertical-align:0em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.22222em">V</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8879999999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mopen mtight" style="">(</span><span class="mord mtight coloredeq eqe" style=""><span class="mord mathnormal mtight" style="">p</span></span><span class="mclose mtight" style="">)</span></span></span></span></span></span></span></span></span></span></span></span></span> is smallest set of tokens such that</p>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:2.541535em;vertical-align:-1.49153em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.050005em;"><span style="top:-1.75857em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mrel mtight"></span><span class="mord mtight coloredeq eqd" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.22222em">V</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8220357142857143em;"><span style="top:-2.8220357142857138em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5357142857142856em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mopen mtight" style="">(</span><span class="mord mtight coloredeq eqe" style=""><span class="mord mathnormal mtight" style="">p</span></span><span class="mclose mtight" style="">)</span></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.0500049999999996em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op"></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.49153em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqb" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord" style=""></span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">1</span><span class="mrel mtight" style="">:</span><span class="mord mathnormal mtight" style="">i</span><span class="mbin mtight" style=""></span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqe" style=""><span class="mord mathnormal" style="">p</span></span></span></span></span></span></p>
<p>That is, we pick the highest probable tokens until the sum of their probabilities is less that <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqe" style=""><span class="mord mathnormal" style="">p</span></span></span></span></span>.</p>
<p>Then we sample from the selected tokens.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">29</span>
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Nucleus Sampler</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span><span class="k">class</span> <span class="nc">NucleusSampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">p</span></code>
is the sum of probabilities of tokens to pick <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqe" style=""><span class="mord mathnormal" style="">p</span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">sampler</span></code>
is the sampler to use for the selected tokens</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span> <span class="o">=</span> <span class="n">p</span>
<span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">sampler</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Softmax to compute <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqb" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord" style=""></span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">1</span><span class="mrel mtight" style="">:</span><span class="mord mathnormal mtight" style="">i</span><span class="mbin mtight" style=""></span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span> from the logits </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p> Sample from logits with Nucleus Sampling</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Get probabilities <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqb" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord" style=""></span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">1</span><span class="mrel mtight" style="">:</span><span class="mord mathnormal mtight" style="">i</span><span class="mbin mtight" style=""></span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="n">probs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Sort probabilities in descending order </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">sorted_probs</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Get the cumulative sum of probabilities in the sorted order </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="n">cum_sum_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Find the cumulative sums less than <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqe" style=""><span class="mord mathnormal" style="">p</span></span></span></span></span>. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">nucleus</span> <span class="o">=</span> <span class="n">cum_sum_probs</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">p</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Prepend ones so that we add one token after the minimum number of tokens with cumulative probability less that <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqe" style=""><span class="mord mathnormal" style="">p</span></span></span></span></span>. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="n">nucleus</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">nucleus</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">nucleus</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span><span class="p">,)),</span> <span class="n">nucleus</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Get log probabilities and mask out the non-nucleus </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="n">sorted_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">)</span>
<span class="lineno">67</span> <span class="n">sorted_log_probs</span><span class="p">[</span><span class="o">~</span><span class="n">nucleus</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Sample from the sampler </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">sampled_sorted_indexes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="p">(</span><span class="n">sorted_log_probs</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Get the actual indexes </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">res</span> <span class="o">=</span> <span class="n">indices</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">sampled_sorted_indexes</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">return</span> <span class="n">res</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src=../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

File diff suppressed because one or more lines are too long

234
docs/sampling/top_k.html Normal file
View File

@ -0,0 +1,234 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementation of top-k sampling from language models."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Top-k Sampling"/>
<meta name="twitter:description" content="A PyTorch implementation of top-k sampling from language models."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/sampling/top_k.html"/>
<meta property="og:title" content="Top-k Sampling"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Top-k Sampling"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Top-k Sampling"/>
<meta property="og:description" content="A PyTorch implementation of top-k sampling from language models."/>
<title>Top-k Sampling</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/sampling/top_k.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">sampling</a>
</p>
<p>
<a href="https://github.com/sponsors/labmlai" target="_blank">
<img alt="Sponsor"
src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86"
style="max-width:100%;"/></a>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/sampling/top_k.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Top-k Sampling</h1>
<p>Here we first pick the top-k tokens from the distribution of logits, and then sample from them.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">13</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">14</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml_nn.sampling</span> <span class="kn">import</span> <span class="n">Sampler</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Top-k Sampler</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">18</span><span class="k">class</span> <span class="nc">TopKSampler</span><span class="p">(</span><span class="n">Sampler</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">k</span></code>
is the number of tokens to pick </li>
<li><code class="highlight"><span></span><span class="n">sampler</span></code>
is the sampler to use for the top-k tokens</li></ul>
<p><code class="highlight"><span></span><span class="n">sampler</span></code>
can be any sampler that takes a logits tensor as input and returns a token tensor; e.g. <a href="temperature.html">`TemperatureSampler&#x27;</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">22</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">30</span> <span class="bp">self</span><span class="o">.</span><span class="n">k</span> <span class="o">=</span> <span class="n">k</span>
<span class="lineno">31</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">sampler</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p> Sample from logits</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>New logits filled with <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.66666em;vertical-align:-0.08333em;"></span><span class="mord"></span><span class="mord"></span></span></span></span>; i.e. zero probability </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">38</span> <span class="n">zeros</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">logits</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Pick the largest <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.03148em;">k</span></span></span></span> logits and their indices </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span> <span class="n">values</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">k</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Set the values of the top-k selected indices to actual logits. Logits of other tokens remain <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.66666em;vertical-align:-0.08333em;"></span><span class="mord"></span><span class="mord"></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">43</span> <span class="n">zeros</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">values</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Sample from the top-k logits with the specified sampler. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">46</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="p">(</span><span class="n">zeros</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src=../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

View File

@ -314,6 +314,55 @@
</url>
<url>
<loc>https://nn.labml.ai/sampling/experiment_tiny.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/sampling/greedy.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/sampling/index.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/sampling/top_k.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/sampling/temperature.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/sampling/experiment.html</loc>
<lastmod>2022-05-07T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/sampling/nucleus.html</loc>
<lastmod>2022-07-29T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/index.html</loc>
<lastmod>2022-07-25T16:30:00+00:00</lastmod>

View File

@ -109,8 +109,10 @@
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code>image_path</code> is the path to the images </li>
<li><code>mask_path</code> is the path to the masks</li></ul>
<ul><li><code class="highlight"><span></span><span class="n">image_path</span></code>
is the path to the images </li>
<li><code class="highlight"><span></span><span class="n">mask_path</span></code>
is the path to the masks</li></ul>
</div>
<div class='code'>
@ -174,7 +176,8 @@
<a href='#section-7'>#</a>
</div>
<h4>Get an image and its mask.</h4>
<ul><li><code>idx</code> is index of the image</li></ul>
<ul><li><code class="highlight"><span></span><span class="n">idx</span></code>
is index of the image</li></ul>
</div>
<div class='code'>

View File

@ -107,8 +107,10 @@
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code>in_channels</code> is the number of input channels </li>
<li><code>out_channels</code> is the number of output channels</li></ul>
<ul><li><code class="highlight"><span></span><span class="n">in_channels</span></code>
is the number of input channels </li>
<li><code class="highlight"><span></span><span class="n">out_channels</span></code>
is the number of output channels</li></ul>
</div>
<div class='code'>
@ -294,8 +296,10 @@
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<ul><li><code>x</code> current feature map in the expansive path </li>
<li><code>contracting_x</code> corresponding feature map from the contracting path</li></ul>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
current feature map in the expansive path </li>
<li><code class="highlight"><span></span><span class="n">contracting_x</span></code>
corresponding feature map from the contracting path</li></ul>
</div>
<div class='code'>
@ -355,8 +359,10 @@
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<ul><li><code>in_channels</code> number of channels in the input image </li>
<li><code>out_channels</code> number of channels in the result feature map</li></ul>
<ul><li><code class="highlight"><span></span><span class="n">in_channels</span></code>
number of channels in the input image </li>
<li><code class="highlight"><span></span><span class="n">out_channels</span></code>
number of channels in the result feature map</li></ul>
</div>
<div class='code'>
@ -466,7 +472,8 @@
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<ul><li><code>x</code> input image</li></ul>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
input image</li></ul>
</div>
<div class='code'>

View File

@ -0,0 +1,31 @@
"""
---
title: Sampling Techniques for Language Models
summary: >
A set of PyTorch implementations/tutorials of sampling techniques for language models.
---
# Sampling Techniques for Language Models
* [Greedy Sampling](greedy.html)
* [Temperature Sampling](temperature.html)
* [Top-k Sampling](top_k.html)
* [Nucleus Sampling](nucleus.html)
Here's an [experiment](experiment.html) that uses these sampling techniques.
"""
import torch
class Sampler:
"""
### Sampler base class
"""
def __call__(self, logits: torch.Tensor) -> torch.Tensor:
"""
### Sample from logits
:param logits: are the logits of the distribution of shape `[..., n_tokens]`
"""
raise NotImplementedError()

View File

@ -0,0 +1,111 @@
"""
---
title: Trying out Sampling Techniques for Language Models
summary: >
We try out different sampling techniques for language models on HuggingFace's GPT2 model.
---
# Trying out Sampling Techniques for Language Models
* [Greedy Sampling](greedy.html)
* [Temperature Sampling](temperature.html)
* [Top-k Sampling](top_k.html)
* [Nucleus Sampling](nucleus.html)
This experiment uses the above sampling techniques, on HuggingFace's GPT2 model.
"""
import torch
from labml import monit, logger, lab
from labml.logger import Text
from labml_nn.sampling import Sampler
from labml_nn.sampling.greedy import GreedySampler
from labml_nn.sampling.nucleus import NucleusSampler
from labml_nn.sampling.temperature import TemperatureSampler
from labml_nn.sampling.top_k import TopKSampler
from transformers import GPT2Tokenizer, GPT2LMHeadModel
@torch.no_grad()
def sample(model: GPT2LMHeadModel, tokenizer: GPT2Tokenizer, sampler: Sampler,
n_samples: int, n_tokens: int, seq_len: int, prompt: str):
"""
## Sample from model
:param model: is the model to sample from
:param tokenizer: is the tokenizer to use
:param sampler: is the sampler to use
:param n_samples: is the number of samples to generate
:param n_tokens: is the number of tokens to generate
:param seq_len: is the maximum sequence length for the model
:param prompt: is the starting prompt
"""
# Tokenize the `prompt` and make `n_samples` copies of it
data = torch.tile(torch.tensor(tokenizer.encode(prompt))[None, :], (n_samples, 1))
# Collect output for printing
logs = [[(prompt, Text.meta)] for _ in range(n_samples)]
# Sample `n_tokens`
for i in monit.iterate('Sample', n_tokens):
# Truncate the data to the maximum sequence length
data = data[-seq_len:]
# Get the model output. The 'logits' has shape `[batch_size, seq_len, n_tokens]`
logits = model(data)[0]
# Get the `logits` of the last token
logits = logits[:, -1]
# Sample from the `logits`
res = sampler(logits)
# Add the sampled token to the data
data = torch.cat([data, res[:, None]], dim=1)
# Decode and add the sampled token for logging
for j in range(n_samples):
logs[j] += [('' + tokenizer.decode(res[j]), Text.value)]
# Print the sampled outputs
for j in range(n_samples):
logger.log(logs[j])
def main():
"""
### Try different sampling techniques
"""
# Load the model and tokenizer
with monit.section('Load tokenizer/model'):
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')
model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')
# Set the model to eval mode
model.eval()
# Prompts to use for sampling
prompt = 'I saw an interesting dream last night. '
# [Greedy Sampling](greedy.html)
with monit.section('greedy'):
sample(model, tokenizer, GreedySampler(), 4, 32, 128, prompt)
# [Temperature Sampling](temperature.html)
with monit.section('temperature=1.'):
sample(model, tokenizer, TemperatureSampler(1.), 4, 32, 128, prompt)
with monit.section('temperature=.1'):
sample(model, tokenizer, TemperatureSampler(.1), 4, 32, 128, prompt)
with monit.section('temperature=10.'):
sample(model, tokenizer, TemperatureSampler(10.), 4, 32, 128, prompt)
# [Top-k Sampling](top_k.html)
with monit.section('top_k=5'):
sample(model, tokenizer, TopKSampler(2, TemperatureSampler(1.)), 4, 32, 128, prompt)
# [Nucleus Sampling](nucleus.html)
with monit.section('nucleus p=.95'):
sample(model, tokenizer, NucleusSampler(0.95, TemperatureSampler(1.)), 4, 32, 128, prompt)
with monit.section('nucleus p=.1'):
sample(model, tokenizer, NucleusSampler(0.1, TemperatureSampler(1.)), 4, 32, 128, prompt)
#
if __name__ == '__main__':
main()

View File

@ -0,0 +1,82 @@
from typing import Tuple
import torch
from labml import experiment, monit
from labml import logger
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset
from labml_nn.sampling import Sampler
from labml_nn.sampling.greedy import GreedySampler
from labml_nn.sampling.nucleus import NucleusSampler
from labml_nn.sampling.temperature import TemperatureSampler
from labml_nn.sampling.top_k import TopKSampler
from labml_nn.transformers.basic.autoregressive_experiment import Configs, AutoregressiveTransformer
def get_model_dataset(run_uuid: str) -> Tuple[AutoregressiveTransformer, TextDataset]:
experiment.evaluate()
conf = Configs()
experiment.configs(conf, experiment.load_configs(run_uuid))
experiment.load(run_uuid)
experiment.add_pytorch_models({'model': conf.model})
experiment.start()
return conf.model, conf.text
def sample(model, ds, sampler: Sampler, n_samples: int, n_tokens: int, seq_len: int, prompt: str):
with torch.no_grad():
data = torch.tile(ds.text_to_i(prompt)[:, None], (1, n_samples))
# Collect output for printing
logs = [[(prompt, Text.meta)] for _ in range(n_samples)]
# Sample 25 tokens
for i in monit.iterate('Sample', n_tokens):
# Tokenize the prompt
data = data[-seq_len:]
# Get the model output
logits, *_ = model(data)
logits = logits[-1]
# Get the model prediction (greedy)
res = sampler(logits)
data = torch.cat([data, res[None, :]], dim=0)
# Add the prediction for logging
for j in range(n_samples):
logs[j] += [('' + ds.itos[res[j]], Text.value)]
# Print the sampled output
for j in range(n_samples):
logger.log(logs[j])
def main():
model, ds = get_model_dataset('074d4004cc6b11ecad7a0242ac1c0002')
model.eval()
with monit.section('greedy'):
sample(model, ds, GreedySampler(), 4, 32, 128, 'It is')
with monit.section('temperature=1.'):
sample(model, ds, TemperatureSampler(1.), 4, 32, 128, 'It is')
with monit.section('temperature=.1'):
sample(model, ds, TemperatureSampler(.1), 4, 32, 128, 'It is')
with monit.section('temperature=10.'):
sample(model, ds, TemperatureSampler(10.), 4, 32, 128, 'It is')
with monit.section('top_k=5'):
sample(model, ds, TopKSampler(2, TemperatureSampler(1.)), 4, 32, 128, 'It is')
with monit.section('nucles p=.95'):
sample(model, ds, NucleusSampler(0.95, TemperatureSampler(1.)), 4, 32, 128, 'It is')
with monit.section('nucles p=.95'):
sample(model, ds, NucleusSampler(0.1, TemperatureSampler(1.)), 4, 32, 128, 'It is')
if __name__ == '__main__':
main()

View File

@ -0,0 +1,22 @@
"""
---
title: Greedy Sampling
summary: A PyTorch implementation of greedy sampling from language models.
---
# Greedy Sampling
Here we sample the most likely token from the distribution of logits.
"""
import torch
from labml_nn.sampling import Sampler
class GreedySampler(Sampler):
def __call__(self, logits: torch.Tensor):
"""
Sample the most likely token from the distribution of logits
"""
return logits.argmax(dim=-1)

View File

@ -0,0 +1,76 @@
"""
---
title: Nucleus Sampling
summary: A PyTorch implementation of nucleus sampling from language models.
---
# Nucleus Sampling
This is an implementation of nucleus sampling, introduced in the paper
[The Curious Case of Neural Text Degeneration](https://papers.labml.ai/paper/1904.09751).
The paper discusses the problems with other sampling methods such as Beam Search,
[Pure sampling](temperature.html), [Temperature sampling](temperature.html), and
[Top-k sampling](top_k.html). The paper introduces the idea of nucleus sampling,
which practically performs better than other sampling methods for text generation.
Nucleus sampling first picks a subset of the vocabulary $V^{(p)} \subset V$,
where $V^{(p)}$ is smallest set of tokens such that
$$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \ge p$$
That is, we pick the highest probable tokens until the sum of their probabilities is less that $p$.
Then we sample from the selected tokens.
"""
import torch
from torch import nn
from labml_nn.sampling import Sampler
class NucleusSampler(Sampler):
"""
## Nucleus Sampler
"""
def __init__(self, p: float, sampler: Sampler):
"""
:param p: is the sum of probabilities of tokens to pick $p$
:param sampler: is the sampler to use for the selected tokens
"""
self.p = p
self.sampler = sampler
# Softmax to compute $P(x_i | x_{1:i-1})$ from the logits
self.softmax = nn.Softmax(dim=-1)
def __call__(self, logits: torch.Tensor):
"""
Sample from logits with Nucleus Sampling
"""
# Get probabilities $P(x_i | x_{1:i-1})$
probs = self.softmax(logits)
# Sort probabilities in descending order
sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
# Get the cumulative sum of probabilities in the sorted order
cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
# Find the cumulative sums less than $p$.
nucleus = cum_sum_probs < self.p
# Prepend ones so that we add one token after the minimum number
# of tokens with cumulative probability less that $p$.
nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
# Get log probabilities and mask out the non-nucleus
sorted_log_probs = torch.log(sorted_probs)
sorted_log_probs[~nucleus] = float('-inf')
# Sample from the sampler
sampled_sorted_indexes = self.sampler(sorted_log_probs)
# Get the actual indexes
res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))
#
return res.squeeze(-1)

View File

@ -0,0 +1,42 @@
"""
---
title: Sampling from Language Models with Temperature
summary: A PyTorch implementation of sampling from language models with temperature.
---
# Sampling from Language Models with Temperature
Here we sample from the following probability distribution where $V$ is the vocabulary,
$u_{1:|V|}$ are the logits of the distribution and T is the temperature:
$$P(x_i=V_l | x_{1:i-1}) = \frac{\exp(\frac{u_l}{T})}{\sum_j \exp(\frac{u_j}{T})}$$
$T = 1$ is normal random sampling.
"""
import torch
from torch.distributions import Categorical
from labml_nn.sampling import Sampler
class TemperatureSampler(Sampler):
"""
## Sampler with Temperature
"""
def __init__(self, temperature: float = 1.0):
"""
:param temperature: is the temperature to sample with
"""
self.temperature = temperature
def __call__(self, logits: torch.Tensor):
"""
Sample from logits
"""
# Create a categorical distribution with temperature adjusted logits
dist = Categorical(logits=logits / self.temperature)
# Sample
return dist.sample()

View File

@ -0,0 +1,46 @@
"""
---
title: Top-k Sampling
summary: A PyTorch implementation of top-k sampling from language models.
---
# Top-k Sampling
Here we first pick the top-k tokens from the distribution of logits, and then
sample from them.
"""
import torch
from labml_nn.sampling import Sampler
class TopKSampler(Sampler):
"""
## Top-k Sampler
"""
def __init__(self, k: int, sampler: Sampler):
"""
:param k: is the number of tokens to pick
:param sampler: is the sampler to use for the top-k tokens
`sampler` can be any sampler that takes a logits tensor as input and returns a token tensor;
e.g. [`TemperatureSampler'](temperature.html).
"""
self.k = k
self.sampler = sampler
def __call__(self, logits: torch.Tensor):
"""
Sample from logits
"""
# New logits filled with $-\infty$; i.e. zero probability
zeros = logits.new_ones(logits.shape) * float('-inf')
# Pick the largest $k$ logits and their indices
values, indices = torch.topk(logits, self.k, dim=-1)
# Set the values of the top-k selected indices to actual logits.
# Logits of other tokens remain $-\infty$
zeros.scatter_(-1, indices, values)
# Sample from the top-k logits with the specified sampler.
return self.sampler(zeros)