ここでは、次の確率分布からサンプリングします。ここで、はボキャブラリー、は分布のロジット、Tは温度です。
は通常のランダムサンプリングです。
19import torch
20from torch.distributions import Categorical
21
22from labml_nn.sampling import Sampler25class TemperatureSampler(Sampler):temperature
はサンプリングする温度29    def __init__(self, temperature: float = 1.0):33        self.temperature = temperatureロジットからのサンプル
35    def __call__(self, logits: torch.Tensor):温度調整済みロジットによるカテゴリ分布の作成
41        dist = Categorical(logits=logits / self.temperature)[サンプル]
44        return dist.sample()