15import torch
16
17from labml_nn.sampling import Sampler20class TopKSampler(Sampler):k
選択するトークンの数ですsampler
トップkのトークンに使用するサンプラーですsampler
ロジッツテンソルを入力として受け取り、トークンテンソルを返すサンプラーならどれでもかまいません(例:`TemperatureSampler')。
24    def __init__(self, k: int, sampler: Sampler):32        self.k = k
33        self.sampler = samplerロジットからのサンプル
35    def __call__(self, logits: torch.Tensor):新しいロジットを埋める、つまり確率がゼロ
40        zeros = logits.new_ones(logits.shape) * float('-inf')最大のロジットとそのインデックスを選択してください
42        values, indices = torch.topk(logits, self.k, dim=-1)選択した上位kのインデックスの値を実際のロジットに設定します。他のトークンのロジットは残ります
45        zeros.scatter_(-1, indices, values)指定されたサンプラーを使用して、上からk個のロジットをサンプリングします。
48        return self.sampler(zeros)