1from typing import Tuple
2
3import torch
4
5from labml import experiment, monit
6from labml import logger
7from labml.logger import Text
8from labml_helpers.datasets.text import TextDataset
9from labml_nn.sampling import Sampler
10from labml_nn.sampling.greedy import GreedySampler
11from labml_nn.sampling.nucleus import NucleusSampler
12from labml_nn.sampling.temperature import TemperatureSampler
13from labml_nn.sampling.top_k import TopKSampler
14from labml_nn.transformers.basic.autoregressive_experiment import Configs, AutoregressiveTransformer
def get_model_dataset(run_uuid: str) -> Tuple[AutoregressiveTransformer, TextDataset]:
    """
    Restore the experiment `run_uuid` for evaluation.

    Builds a fresh `Configs`, applies the configurations loaded for `run_uuid`,
    loads that run, registers the model with the experiment and starts it.

    Returns the `(model, text dataset)` pair taken from the restored configurations.
    """
    experiment.evaluate()

    cfg = Configs()
    # Configuration values recorded for the run we are restoring
    saved = experiment.load_configs(run_uuid)
    experiment.configs(cfg, saved)

    experiment.load(run_uuid)
    # Register the model under the name 'model' (presumably so the loaded
    # checkpoint is restored into it when the experiment starts — confirm)
    experiment.add_pytorch_models({'model': cfg.model})
    experiment.start()

    return cfg.model, cfg.text
def sample(model, ds, sampler: Sampler, n_samples: int, n_tokens: int, seq_len: int, prompt: str):
    """
    Autoregressively sample `n_tokens` tokens after `prompt` and log the results.

    * `model` is called as `model(data)`; the first element of its output is used as logits
    * `ds` provides `text_to_i` (prompt text -> token-id tensor) and `itos` (token id -> string)
    * `sampler` maps the last-position logits to one token id per sample
    * `n_samples` is the number of sequences sampled in parallel
    * `n_tokens` is the number of tokens to generate
    * `seq_len` is the maximum context length fed to the model
    * `prompt` is the starting text
    """
    with torch.no_grad():
        # Repeat the prompt once per sample. Assuming `text_to_i` returns a 1-D tensor,
        # `data` is laid out as `[seq, n_samples]`.
        data = torch.tile(ds.text_to_i(prompt)[:, None], (1, n_samples))
        # Put the prompt on the model's device; without this a model placed on the
        # GPU would be fed CPU tensors and fail.
        param = next(model.parameters(), None)
        if param is not None:
            data = data.to(param.device)

        # Collect output for printing
        logs = [[(prompt, Text.meta)] for _ in range(n_samples)]
        # Sample `n_tokens` tokens
        for _ in monit.iterate('Sample', n_tokens):
            # Keep only the last `seq_len` tokens as context
            data = data[-seq_len:]
            # Get the model output
            logits, *_ = model(data)
            # Logits at the last position only
            logits = logits[-1]
            # Pick the next token of every sample with the given sampling strategy
            res = sampler(logits)
            # Append the new tokens to the context
            data = torch.cat([data, res[None, :]], dim=0)
            # Add the prediction for logging (one host transfer instead of one per sample)
            for log, token in zip(logs, res.tolist()):
                log.append((ds.itos[token], Text.value))

    # Print the sampled output
    for log in logs:
        logger.log(log)
def main(run_uuid: str = '074d4004cc6b11ecad7a0242ac1c0002', prompt: str = 'It is',
         n_samples: int = 4, n_tokens: int = 32, seq_len: int = 128):
    """
    Load a trained run and print samples produced by several sampling strategies.

    * `run_uuid` is the experiment run to load
    * `prompt` is the text every sample starts from
    * `n_samples` is the number of sequences sampled per strategy
    * `n_tokens` is the number of tokens generated per sequence
    * `seq_len` is the maximum context length fed to the model
    """
    model, ds = get_model_dataset(run_uuid)
    model.eval()

    # (section label, sampler) pairs. Each label sits next to its sampler so the
    # two cannot drift apart.
    samplers = [
        ('greedy', GreedySampler()),
        ('temperature=1.', TemperatureSampler(1.)),
        ('temperature=.1', TemperatureSampler(.1)),
        ('temperature=10.', TemperatureSampler(10.)),
        ('top_k=2', TopKSampler(2, TemperatureSampler(1.))),
        ('nucleus p=.95', NucleusSampler(0.95, TemperatureSampler(1.))),
        ('nucleus p=.1', NucleusSampler(0.1, TemperatureSampler(1.))),
    ]

    for label, sampler in samplers:
        with monit.section(label):
            sample(model, ds, sampler, n_samples, n_tokens, seq_len, prompt)
79
80
# Run the sampling demo only when this file is executed as a script, not on import
if __name__ == '__main__':
    main()