"""
|
|
---
|
|
title: Top-k Sampling
|
|
summary: A PyTorch implementation of top-k sampling from language models.
|
|
---
|
|
|
|
# Top-k Sampling
|
|
|
|
Here we first pick the top-k tokens from the distribution of logits, and then
|
|
sample from them.
|
|
|
|
Here's an [experiment](experiment.html) that uses these sampling techniques.
|
|
"""

import torch

from labml_nn.sampling import Sampler


class TopKSampler(Sampler):
    """
    ## Top-k Sampler
    """

    def __init__(self, k: int, sampler: Sampler):
        """
        :param k: is the number of tokens to pick
        :param sampler: is the sampler to use for the top-k tokens

        `sampler` can be any sampler that takes a logits tensor as input and returns a token tensor;
        e.g. [`TemperatureSampler`](temperature.html).
        """
        self.k = k
        self.sampler = sampler

    def __call__(self, logits: torch.Tensor):
        """
        Sample from logits
        """
        # New logits filled with $-\infty$; i.e. zero probability
        zeros = logits.new_ones(logits.shape) * float('-inf')
        # Pick the largest $k$ logits and their indices
        values, indices = torch.topk(logits, self.k, dim=-1)
        # Set the values of the top-k selected indices to actual logits.
        # Logits of other tokens remain $-\infty$
        zeros.scatter_(-1, indices, values)

        # Sample from the top-k logits with the specified sampler.
        return self.sampler(zeros)
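

# A minimal usage sketch (an illustration, not part of the original module). It assumes
# that `TemperatureSampler` lives in `labml_nn.sampling.temperature` alongside this file
# and takes a temperature value as its constructor argument; the batch and vocabulary
# sizes below are arbitrary.
if __name__ == '__main__':
    from labml_nn.sampling.temperature import TemperatureSampler

    # Keep only the 10 highest logits and sample among them with temperature 1.0
    sampler = TopKSampler(k=10, sampler=TemperatureSampler(1.0))
    # A batch of 4 random logit vectors over a 65-token vocabulary
    logits = torch.randn(4, 65)
    # One sampled token index per row; tokens outside the top 10 have zero probability
    tokens = sampler(logits)
    print(tokens)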