| """
 | |
| ---
 | |
| title: Nucleus Sampling
 | |
| summary: A PyTorch implementation of nucleus sampling from language models.
 | |
| ---
 | |
| 
 | |
| # Nucleus Sampling
 | |
| 
 | |
| This is an implementation of nucleus sampling, introduced in the paper
 | |
| [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751).
 | |
| 
 | |
| The paper discusses the problems with other sampling methods such as Beam Search,
 | |
| [Pure sampling](temperature.html), [Temperature sampling](temperature.html), and
 | |
| [Top-k sampling](top_k.html). The paper introduces the idea of nucleus sampling,
 | |
| which practically performs better than other sampling methods for text generation.
 | |
| 
 | |
| Nucleus sampling first picks a subset of the vocabulary $V^{(p)} \subset V$,
 | |
| where $V^{(p)}$ is smallest set of tokens such that
 | |
| 
 | |
| $$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \ge p$$
 | |
| 
 | |
| That is, we pick the highest probable tokens until the sum of their probabilities is less that $p$.
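
For example, if $p = 0.75$ and the sorted probabilities are $0.4, 0.3, 0.2, 0.1$,
the nucleus is the first three tokens: $0.4 + 0.3 = 0.7 < 0.75$, and adding the
third token brings the cumulative probability to $0.9 \ge 0.75$.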

Then we sample from the selected tokens.

Here's an [experiment](experiment.html) that uses these sampling techniques.
"""

import torch
from torch import nn

from labml_nn.sampling import Sampler


class NucleusSampler(Sampler):
    """
    ## Nucleus Sampler
    """
    def __init__(self, p: float, sampler: Sampler):
        """
        :param p: is the cumulative probability threshold $p$; the nucleus is the smallest
            set of tokens whose probabilities sum to at least $p$
        :param sampler: is the sampler to use for the selected tokens
        """
        self.p = p
        self.sampler = sampler
        # Softmax to compute $P(x_i | x_{1:i-1})$ from the logits
        self.softmax = nn.Softmax(dim=-1)

    def __call__(self, logits: torch.Tensor):
        """
        Sample from logits with Nucleus Sampling
        """

        # Get probabilities $P(x_i | x_{1:i-1})$
        probs = self.softmax(logits)

        # Sort probabilities in descending order
        sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
        # Get the cumulative sum of probabilities in the sorted order
        cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
        # Find the cumulative sums less than $p$.
        nucleus = cum_sum_probs < self.p
        # Prepend ones and drop the last element to shift the mask right by one,
        # so that we include one token beyond those whose cumulative probability
        # is less than $p$. This makes the nucleus the smallest set of tokens
        # whose probabilities sum to at least $p$.
        nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
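        # For example, with $p = 0.75$ and sorted probabilities $[0.4, 0.3, 0.2, 0.1]$ the
        # cumulative sums are $[0.4, 0.7, 0.9, 1.0]$; the mask before shifting is $[1, 1, 0, 0]$
        # and after shifting it is $[1, 1, 1, 0]$, so the nucleus keeps the three most probable tokens.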

        # Get log probabilities and mask out the non-nucleus tokens
        sorted_log_probs = torch.log(sorted_probs)
        sorted_log_probs[~nucleus] = float('-inf')

        # Sample from the masked log probabilities with the given sampler
        sampled_sorted_indexes = self.sampler(sorted_log_probs)

        # Map the sampled indexes, which are positions in the sorted order,
        # back to the actual token indexes in the vocabulary
        res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))
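        # For example, if a row's sampled sorted index is $2$, this picks the vocabulary id of
        # that row's third most probable token.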

        # Squeeze the extra dimension added for `gather` and return the sampled token indexes
        return res.squeeze(-1)
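

# The block below is not part of the original module; it is a minimal usage sketch showing how
# `NucleusSampler` wraps another sampler. `MultinomialSampler` and the random logits are made up
# purely for illustration; the samplers linked above (e.g. temperature sampling) can be used here as well.
if __name__ == '__main__':
    class MultinomialSampler(Sampler):
        """Sample a token index from (possibly masked) logits."""
        def __call__(self, logits: torch.Tensor):
            # Softmax turns the $-\infty$ entries outside the nucleus into zero probability
            return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1).squeeze(-1)

    # Fake logits for a batch of two sequences over a vocabulary of ten tokens
    logits = torch.randn(2, 10)
    sampler = NucleusSampler(p=0.95, sampler=MultinomialSampler())
    # Returns one sampled token index per batch element
    print(sampler(logits))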