Sampling Techniques (#139)

This commit is contained in:
Varuna Jayasiri
2022-08-08 11:12:32 +05:30
committed by GitHub
parent 940b3c01fc
commit f3189e2331
24 changed files with 2358 additions and 37 deletions

View File

@ -0,0 +1,46 @@
"""
---
title: Top-k Sampling
summary: A PyTorch implementation of top-k sampling from language models.
---
# Top-k Sampling
Here we first pick the top-k tokens from the distribution of logits, and then
sample from them.
"""
import torch
from labml_nn.sampling import Sampler
class TopKSampler(Sampler):
    """
    ## Top-k Sampler

    Keeps only the $k$ largest logits along the last dimension, masks every
    other position with negative infinity, and hands the result to another
    sampler that draws the token.
    """

    def __init__(self, k: int, sampler: Sampler):
        """
        :param k: is the number of tokens to pick
        :param sampler: is the sampler to use for the top-k tokens

        `sampler` can be any sampler that takes a logits tensor as input and returns a token tensor;
        e.g. [`TemperatureSampler`](temperature.html).
        """
        self.k = k
        self.sampler = sampler

    def __call__(self, logits: torch.Tensor):
        """
        Sample from logits

        :param logits: are the logits; the top-k selection runs along the last dimension
        :return: the result of calling `sampler` on the filtered logits
        """
        # `torch.topk` raises an error when $k$ is larger than the dimension it
        # selects from, so cap $k$ at the size of the last dimension.
        # With the cap in effect every logit is kept.
        k = min(self.k, logits.shape[-1])
        # New logits filled with $-\infty$; i.e. zero probability
        filtered = torch.full_like(logits, float('-inf'))
        # Pick the largest $k$ logits and their indices
        values, indices = torch.topk(logits, k, dim=-1)
        # Set the values of the top-k selected indices to actual logits.
        # Logits of other tokens remain $-\infty$
        filtered.scatter_(-1, indices, values)
        # Sample from the top-k logits with the specified sampler.
        return self.sampler(filtered)