# Mirror of:
# https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
# Synced: 2025-08-16 10:51:23 +08:00
#
# Demo of text-sampling techniques (greedy, temperature, top-k, nucleus)
# on a trained autoregressive transformer.
from typing import Tuple

import torch

from labml import experiment, monit
from labml import logger
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset
from labml_nn.sampling import Sampler
from labml_nn.sampling.greedy import GreedySampler
from labml_nn.sampling.nucleus import NucleusSampler
from labml_nn.sampling.temperature import TemperatureSampler
from labml_nn.sampling.top_k import TopKSampler
from labml_nn.transformers.basic.autoregressive_experiment import Configs, AutoregressiveTransformer
def get_model_dataset(run_uuid: str) -> Tuple[AutoregressiveTransformer, TextDataset]:
    """
    Load a trained model and its dataset from a saved experiment run.

    :param run_uuid: UUID of the training run whose configs and checkpoint to load.
    :return: the trained autoregressive transformer and the text dataset it was
        trained on.

    NOTE(review): the order of the `experiment.*` calls below matters —
    configs must be set and the model registered before `experiment.start()`
    loads the checkpoint.
    """
    # Put the experiment framework in evaluation mode (no new training run)
    experiment.evaluate()

    # Create an empty configs object
    conf = Configs()
    # Populate it with the configurations saved for the given run
    experiment.configs(conf, experiment.load_configs(run_uuid))
    # Select which run's checkpoint to load
    experiment.load(run_uuid)
    # Register the model so its weights are restored from the checkpoint
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment — this performs the actual checkpoint load
    experiment.start()

    # Return the restored model and the dataset
    return conf.model, conf.text
def sample(model, ds, sampler: Sampler, n_samples: int, n_tokens: int, seq_len: int, prompt: str):
    """
    Generate text from the model with the given sampling strategy and print it.

    :param model: autoregressive model; called as ``model(data)`` with a
        sequence-first token tensor, and expected to return logits as the first
        element (last-step logits are taken with ``logits[-1]``)
    :param ds: dataset providing ``text_to_i`` (text -> token tensor) and
        ``itos`` (token id -> string)
    :param sampler: sampling strategy applied to the last-step logits
    :param n_samples: number of independent samples generated in parallel
    :param n_tokens: number of tokens to generate per sample
    :param seq_len: maximum context length fed to the model
    :param prompt: starting text shared by all samples
    """
    with torch.no_grad():
        # Tokenize the prompt and replicate it across the batch dimension;
        # shape is [prompt_len, n_samples] (sequence-first)
        data = torch.tile(ds.text_to_i(prompt)[:, None], (1, n_samples))

        # Collect output for printing
        logs = [[(prompt, Text.meta)] for _ in range(n_samples)]
        # Sample `n_tokens` tokens
        # (fixed comment: the original said "25 tokens" but the loop runs `n_tokens`)
        for _ in monit.iterate('Sample', n_tokens):
            # Truncate the context to the last `seq_len` tokens
            data = data[-seq_len:]
            # Get the model output
            logits, *_ = model(data)
            # Keep only the logits of the final time step
            logits = logits[-1]
            # Sample the next token for every parallel sample
            res = sampler(logits)
            # Append the sampled tokens to the running sequences
            data = torch.cat([data, res[None, :]], dim=0)

            # Add the prediction for logging
            for j in range(n_samples):
                logs[j] += [(ds.itos[res[j]], Text.value)]

    # Print the sampled output
    for j in range(n_samples):
        logger.log(logs[j])
def main():
    """Load a trained model and demo the different sampling techniques on it."""
    # Load the trained model and dataset from the saved run
    model, ds = get_model_dataset('074d4004cc6b11ecad7a0242ac1c0002')
    # Switch to evaluation mode (disables dropout etc.)
    model.eval()

    # Greedy sampling: always pick the most likely next token
    with monit.section('greedy'):
        sample(model, ds, GreedySampler(), 4, 32, 128, 'It is')

    # Temperature sampling at a few temperatures
    with monit.section('temperature=1.'):
        sample(model, ds, TemperatureSampler(1.), 4, 32, 128, 'It is')
    with monit.section('temperature=.1'):
        sample(model, ds, TemperatureSampler(.1), 4, 32, 128, 'It is')
    with monit.section('temperature=10.'):
        sample(model, ds, TemperatureSampler(10.), 4, 32, 128, 'It is')

    # Top-k sampling (label fixed: the sampler keeps the top 2 tokens, not 5)
    with monit.section('top_k=2'):
        sample(model, ds, TopKSampler(2, TemperatureSampler(1.)), 4, 32, 128, 'It is')

    # Nucleus (top-p) sampling (labels fixed: 'nucles' typo, and the second
    # section uses p=0.1, not p=.95)
    with monit.section('nucleus p=.95'):
        sample(model, ds, NucleusSampler(0.95, TemperatureSampler(1.)), 4, 32, 128, 'It is')
    with monit.section('nucleus p=.1'):
        sample(model, ds, NucleusSampler(0.1, TemperatureSampler(1.)), 4, 32, 128, 'It is')


if __name__ == '__main__':
    main()
|