"""
|
|
---
|
|
title: Nucleus Sampling
|
|
summary: A PyTorch implementation of nucleus sampling from language models.
|
|
---
|
|
|
|
# Nucleus Sampling
|
|
|
|
This is an implementation of nucleus sampling, introduced in the paper
|
|
[The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751).
|
|
|
|
The paper discusses the problems with other sampling methods such as Beam Search,
|
|
[Pure sampling](temperature.html), [Temperature sampling](temperature.html), and
|
|
[Top-k sampling](top_k.html). The paper introduces the idea of nucleus sampling,
|
|
which practically performs better than other sampling methods for text generation.
|
|
|
|
Nucleus sampling first picks a subset of the vocabulary $V^{(p)} \subset V$,
|
|
where $V^{(p)}$ is smallest set of tokens such that
|
|
|
|
$$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \ge p$$
|
|
|
|
That is, we pick the highest probable tokens until the sum of their probabilities is less that $p$.
|
|
|
|
Then we sample from the selected tokens.
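
For example, if the sorted probabilities are $[0.5, 0.3, 0.1, 0.1]$ and $p = 0.75$,
the nucleus is the first two tokens: their cumulative probability is $0.8 \ge 0.75$,
and the remaining tokens are masked out before sampling.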

Here's an [experiment](experiment.html) that uses these sampling techniques.
"""

import torch
from torch import nn

from labml_nn.sampling import Sampler


class NucleusSampler(Sampler):
    """
    ## Nucleus Sampler
    """

    def __init__(self, p: float, sampler: Sampler):
        """
        :param p: is the cumulative probability threshold $p$; we keep the smallest set of tokens whose probabilities sum to at least $p$
        :param sampler: is the sampler to use for the selected tokens
        """
        self.p = p
        self.sampler = sampler
        # Softmax to compute $P(x_i | x_{1:i-1})$ from the logits
        self.softmax = nn.Softmax(dim=-1)

    def __call__(self, logits: torch.Tensor):
        """
        Sample from logits with Nucleus Sampling
        """

        # Get probabilities $P(x_i | x_{1:i-1})$
        probs = self.softmax(logits)
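        # The softmax is over the last dimension, so `logits` can have any shape
        # `[..., n_tokens]` with the vocabulary on the last dimension.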

        # Sort probabilities in descending order
        sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
        # Get the cumulative sum of probabilities in the sorted order
        cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
        # Find the cumulative sums less than $p$.
        nucleus = cum_sum_probs < self.p
        # Prepend ones so that we add one token after the minimum number
        # of tokens with cumulative probability less than $p$.
        nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
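        # For example, with sorted probabilities [0.5, 0.3, 0.1, 0.1] and p = 0.75,
        # the cumulative sums are [0.5, 0.8, 0.9, 1.0], the mask before the shift is
        # [True, False, False, False], and after the shift it is [True, True, False, False].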

        # Get log probabilities and mask out the non-nucleus
        sorted_log_probs = torch.log(sorted_probs)
        sorted_log_probs[~nucleus] = float('-inf')
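        # `exp(-inf) = 0`, so masked tokens get zero probability once the
        # downstream sampler applies a softmax to these log-probabilities.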

        # Sample from the sampler
        sampled_sorted_indexes = self.sampler(sorted_log_probs)

        # Get the actual indexes
        res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))

        # Squeeze out the extra dimension added for `gather` and return
        return res.squeeze(-1)
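

# A minimal usage sketch (an illustrative addition, not part of the original
# module). `MultinomialSampler` below is a hypothetical helper that samples
# from the categorical distribution over the (masked) logits; `labml_nn`
# ships its own samplers that can be used in its place.
class MultinomialSampler(Sampler):
    """
    Sample a token index from the softmax of the logits.
    """

    def __call__(self, logits: torch.Tensor):
        # `-inf` logits get zero probability under softmax
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)


def _usage_example():
    # Fake logits for a batch of two sequences over a vocabulary of 10 tokens
    logits = torch.randn(2, 10)
    # Keep the smallest set of tokens covering 95% of the probability mass
    sampler = NucleusSampler(p=0.95, sampler=MultinomialSampler())
    # Sampled token indices, one per batch element
    print(sampler(logits))


if __name__ == '__main__':
    _usage_example()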