1from typing import Tuple
2
3import torch
4
5from labml import experiment, monit
6from labml import logger
7from labml.logger import Text
8from labml_helpers.datasets.text import TextDataset
9from labml_nn.sampling import Sampler
10from labml_nn.sampling.greedy import GreedySampler
11from labml_nn.sampling.nucleus import NucleusSampler
12from labml_nn.sampling.temperature import TemperatureSampler
13from labml_nn.sampling.top_k import TopKSampler
14from labml_nn.transformers.basic.autoregressive_experiment import Configs, AutoregressiveTransformer
def get_model_dataset(run_uuid: str) -> Tuple[AutoregressiveTransformer, TextDataset]:
    """
    Restore a model and its text dataset from a saved labml experiment run.

    :param run_uuid: UUID of the saved run whose configurations (and, via
        ``experiment.load``, saved state) are used
    :return: ``(model, text_dataset)`` taken from the run's configurations
    """
    # Put labml into evaluation mode before configuring anything
    experiment.evaluate()

    cfg = Configs()
    # Override the defaults with the configuration values saved with the run
    saved_configs = experiment.load_configs(run_uuid)
    experiment.configs(cfg, saved_configs)

    # Select the run to load from and register the model with labml
    # (presumably so the run's saved weights are applied to it on start)
    experiment.load(run_uuid)
    experiment.add_pytorch_models({'model': cfg.model})

    experiment.start()

    return cfg.model, cfg.text


def sample(model, ds, sampler: Sampler, n_samples: int, n_tokens: int, seq_len: int, prompt: str):
    """
    Generate ``n_tokens`` tokens after ``prompt`` for ``n_samples`` sequences
    in parallel, then log each generated text.

    :param model: autoregressive model; it is called with a ``[seq, n_samples]``
        token tensor and must return a tuple whose first element is the logits
        with the sequence dimension first
    :param ds: dataset providing ``text_to_i`` (text -> token ids) and
        ``itos`` (token id -> string)
    :param sampler: strategy that picks the next token of every sequence from
        the last-position logits
    :param n_samples: number of sequences generated in parallel
    :param n_tokens: number of tokens to generate for each sequence
    :param seq_len: maximum number of most recent tokens fed to the model
    :param prompt: text every sequence starts from
    """
    with torch.no_grad():
        # Repeat the tokenized prompt once per sample -> [prompt_len, n_samples]
        # (assumes ``text_to_i`` returns a 1-D tensor — TODO confirm)
        data = torch.tile(ds.text_to_i(prompt)[:, None], (1, n_samples))

        # Collect output for printing; every sample's log starts with the prompt
        logs = [[(prompt, Text.meta)] for _ in range(n_samples)]

        # Sample `n_tokens` tokens
        for _ in monit.iterate('Sample', n_tokens):
            # Crop the context to the last `seq_len` tokens
            data = data[-seq_len:]
            # Get the model output and keep only the last position's logits
            logits, *_ = model(data)
            logits = logits[-1]
            # Pick the next token of every sequence with the given sampler
            res = sampler(logits)
            # Append the new tokens as one more time step: [n_samples] -> [1, n_samples]
            data = torch.cat([data, res[None, :]], dim=0)
            # Add the prediction for logging; convert to plain ints once so
            # `itos` is indexed with Python ints rather than 0-d tensors
            for log, token in zip(logs, res.tolist()):
                log.append((ds.itos[token], Text.value))

    # Print the sampled output
    for log in logs:
        logger.log(log)


def main():
    """
    Load the trained model from a saved run and print samples produced by
    each sampling strategy, so their outputs can be compared.
    """
    model, ds = get_model_dataset('074d4004cc6b11ecad7a0242ac1c0002')
    model.eval()

    # Settings shared by every sampling strategy
    n_samples = 4
    n_tokens = 32
    seq_len = 128
    prompt = 'It is'

    # Build each section label from the same value handed to the sampler, so a
    # label can no longer disagree with what actually ran (previously
    # 'top_k=5' ran k=2, and a 'p=.95' section ran p=0.1).
    top_k = 2
    strategies = [('greedy', GreedySampler())]
    strategies += [(f'temperature={t}', TemperatureSampler(t)) for t in (1., .1, 10.)]
    strategies.append((f'top_k={top_k}', TopKSampler(top_k, TemperatureSampler(1.))))
    strategies += [(f'nucleus p={p}', NucleusSampler(p, TemperatureSampler(1.))) for p in (.95, .1)]

    for label, sampler in strategies:
        with monit.section(label):
            sample(model, ds, sampler, n_samples, n_tokens, seq_len, prompt)
79
80
if __name__ == '__main__':
    # Run the sampling comparison only when executed as a script, not on import
    main()