"""
---
title: Nucleus Sampling
summary: A PyTorch implementation of nucleus sampling from language models.
---

# Nucleus Sampling

This is an implementation of nucleus sampling, introduced in the paper
[The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751).

The paper discusses the problems with other sampling methods such as Beam Search,
[Pure sampling](temperature.html), [Temperature sampling](temperature.html), and
[Top-k sampling](top_k.html). The paper introduces the idea of nucleus sampling,
which in practice performs better than the other sampling methods for text generation.

Nucleus sampling first picks a subset of the vocabulary $V^{(p)} \subset V$,
where $V^{(p)}$ is the smallest set of tokens such that

$$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \ge p$$

That is, we pick the most probable tokens until the sum of their probabilities reaches $p$.

Then we sample from the selected tokens.
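
Here is a quick worked example of the nucleus computation (the numbers are
illustrative, not from the paper), mirroring what `NucleusSampler` below does:

```python
import torch

p = 0.9
sorted_probs = torch.tensor([0.4, 0.3, 0.2, 0.1])  # already in descending order
cum_sum = torch.cumsum(sorted_probs, dim=-1)       # [0.4, 0.7, 0.9, 1.0]
nucleus = cum_sum < p                              # [True, True, False, False]
# Shift right, prepending True, so the token that crosses $p$ is also kept:
# the nucleus is the first three tokens, covering probability $0.9 \ge p$
```
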
Here's an [experiment](experiment.html) that uses these sampling techniques.
"""

import torch
from torch import nn

from labml_nn.sampling import Sampler


class NucleusSampler(Sampler):
    """
    ## Nucleus Sampler
    """

    def __init__(self, p: float, sampler: Sampler):
        """
        :param p: is the probability threshold $p$; the most probable tokens with a cumulative probability of at least $p$ are picked
        :param sampler: is the sampler to use for the selected tokens
        """
        self.p = p
        self.sampler = sampler
        # Softmax to compute $P(x_i | x_{1:i-1})$ from the logits
        self.softmax = nn.Softmax(dim=-1)

    def __call__(self, logits: torch.Tensor):
        """
        Sample from logits with Nucleus Sampling

        :param logits: are the logits of the token distribution, with the vocabulary along the last dimension
        """

        # Get probabilities $P(x_i | x_{1:i-1})$
        probs = self.softmax(logits)
        # Sort probabilities in descending order
        sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
        # Get the cumulative sum of probabilities in the sorted order
        cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
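        # e.g. sorted probabilities [0.4, 0.3, 0.2, 0.1] give cumulative sums [0.4, 0.7, 0.9, 1.0]
        # (illustrative values, matching the worked example in the module docstring)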
        # Find the cumulative sums less than $p$
        nucleus = cum_sum_probs < self.p
        # Prepend ones so that we add one token after the minimum number
        # of tokens with cumulative probability less than $p$
        nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
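        # e.g. with $p = 0.9$ the mask [True, True, False, False] becomes [True, True, True, False],
        # so the nucleus keeps the first three tokens and covers probability $0.9 \ge p$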
        # Get log probabilities and mask out the non-nucleus
        sorted_log_probs = torch.log(sorted_probs)
        sorted_log_probs[~nucleus] = float('-inf')
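        # Tokens set to $-\infty$ get zero probability once the downstream
        # sampler re-normalizes (e.g. with its own softmax)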
        # Sample from the sampler
        sampled_sorted_indexes = self.sampler(sorted_log_probs)
        # Get the actual indexes
        res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))

        #
        return res.squeeze(-1)
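

# A minimal usage sketch (not part of the original file). It assumes
# `TemperatureSampler` from [temperature.html](temperature.html) is available
# at `labml_nn.sampling.temperature` and samples from the re-normalized
# distribution it is given.
if __name__ == '__main__':
    from labml_nn.sampling.temperature import TemperatureSampler

    # Nucleus sampling with $p = 0.95$, drawing from temperature-scaled logits
    sampler = NucleusSampler(p=0.95, sampler=TemperatureSampler(1.0))
    # A batch of 4 dummy distributions over a 100-token vocabulary
    dummy_logits = torch.randn(4, 100)
    # Prints a tensor of 4 sampled token indices
    print(sampler(dummy_logits))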