Here we first pick the $k$ tokens with the highest logits, and then sample from only those tokens.
Here's an experiment that uses these sampling techniques.
```python
import torch

from labml_nn.sampling import Sampler


class TopKSampler(Sampler):
    """
    ## Top-k Sampler
    """
    def __init__(self, k: int, sampler: Sampler):
        """
        :param k: is the number of tokens to pick
        :param sampler: is the sampler to use for the top-k tokens

        `sampler` can be any sampler that takes a logits tensor as
        input and returns a token tensor; e.g. `TemperatureSampler`.
        """
        self.k = k
        self.sampler = sampler

    def __call__(self, logits: torch.Tensor):
        """
        Sample from logits
        """
        # New logits filled with $-\infty$; i.e. zero probability
        zeros = logits.new_ones(logits.shape) * float('-inf')
        # Pick the largest $k$ logits and their indices
        values, indices = torch.topk(logits, self.k, dim=-1)
        # Set the values of the top-k selected indices to actual logits.
        # Logits of other tokens remain $-\infty$
        zeros.scatter_(-1, indices, values)
        # Sample from the top-k logits with the specified sampler.
        return self.sampler(zeros)
```
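To make the masking step concrete, here is a small standalone sketch (with made-up numbers) of how `torch.topk` and `scatter_` combine to keep only the $k$ largest logits:

```python
import torch

# Toy logits over a 4-token vocabulary
logits = torch.tensor([1.0, 3.0, 0.5, 2.0])
# Start from all -inf, i.e. zero probability everywhere
zeros = logits.new_ones(logits.shape) * float('-inf')
# The 2 largest logits and their positions
values, indices = torch.topk(logits, 2, dim=-1)  # values=[3.0, 2.0], indices=[1, 3]
# Restore only the top-2 logits
zeros.scatter_(-1, indices, values)  # tensor([-inf, 3.0, -inf, 2.0])
# After softmax, tokens 0 and 2 get exactly zero probability
```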
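And a minimal usage sketch; this assumes `TemperatureSampler` is importable from `labml_nn.sampling.temperature` and takes the temperature as its constructor argument:

```python
import torch

from labml_nn.sampling.temperature import TemperatureSampler

# Random logits for a batch of 2 sequences over a 10-token vocabulary
logits = torch.randn(2, 10)
# Keep the 3 largest logits and sample among them at temperature 1.0
sampler = TopKSampler(k=3, sampler=TemperatureSampler(1.0))
tokens = sampler(logits)  # one sampled token index per batch row
```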