Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
from typing import Tuple

import torch

from labml import experiment, monit
from labml import logger
from labml.logger import Text
from labml_nn.helpers.datasets import TextDataset
from labml_nn.sampling import Sampler
from labml_nn.sampling.greedy import GreedySampler
from labml_nn.sampling.nucleus import NucleusSampler
from labml_nn.sampling.temperature import TemperatureSampler
from labml_nn.sampling.top_k import TopKSampler
from labml_nn.transformers.basic.autoregressive_experiment import Configs, AutoregressiveTransformer


def get_model_dataset(run_uuid: str) -> Tuple[AutoregressiveTransformer, TextDataset]:
    """Load the trained model and dataset from the training run `run_uuid`."""
    # Create an experiment in evaluation mode
    experiment.evaluate()

    # Create configurations
    conf = Configs()

    # Load the custom configurations used by the training run
    experiment.configs(conf, experiment.load_configs(run_uuid))

    # Select the checkpoint of the training run to load
    experiment.load(run_uuid)

    # Register the model so its saved weights get loaded
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment; this is when the checkpoint is actually loaded
    experiment.start()

    # Return the model and the dataset
    return conf.model, conf.text


def sample(model, ds, sampler: Sampler, n_samples: int, n_tokens: int, seq_len: int, prompt: str):
    """Generate `n_tokens` tokens with the given sampler, starting `n_samples` sequences from `prompt`."""
    with torch.no_grad():
        # Tokenize the prompt and tile it so each of the `n_samples` columns starts from the same prompt
        data = torch.tile(ds.text_to_i(prompt)[:, None], (1, n_samples))

        # Collect output for printing
        logs = [[(prompt, Text.meta)] for _ in range(n_samples)]
        # Sample `n_tokens` tokens
        for i in monit.iterate('Sample', n_tokens):
            # Truncate the context to the last `seq_len` tokens
            data = data[-seq_len:]
            # Get the model output
            logits, *_ = model(data)
            # Keep only the logits for the last position
            logits = logits[-1]
            # Sample the next token with the given sampler
            res = sampler(logits)
            # Append the sampled tokens to the sequences
            data = torch.cat([data, res[None, :]], dim=0)
            # Add the sampled tokens for logging
            for j in range(n_samples):
                logs[j] += [(ds.itos[res[j]], Text.value)]

    # Print the sampled outputs
    for j in range(n_samples):
        logger.log(logs[j])


def main():
    # Load the model and dataset from the given training run UUID
    model, ds = get_model_dataset('074d4004cc6b11ecad7a0242ac1c0002')
    # Switch the model to evaluation mode
    model.eval()

    # Greedy sampling
    with monit.section('greedy'):
        sample(model, ds, GreedySampler(), 4, 32, 128, 'It is')

    # Temperature sampling at different temperatures
    with monit.section('temperature=1.'):
        sample(model, ds, TemperatureSampler(1.), 4, 32, 128, 'It is')
    with monit.section('temperature=.1'):
        sample(model, ds, TemperatureSampler(.1), 4, 32, 128, 'It is')
    with monit.section('temperature=10.'):
        sample(model, ds, TemperatureSampler(10.), 4, 32, 128, 'It is')

    # Top-k sampling
    with monit.section('top_k=2'):
        sample(model, ds, TopKSampler(2, TemperatureSampler(1.)), 4, 32, 128, 'It is')

    # Nucleus sampling at different cumulative-probability thresholds
    with monit.section('nucleus p=.95'):
        sample(model, ds, NucleusSampler(0.95, TemperatureSampler(1.)), 4, 32, 128, 'It is')
    with monit.section('nucleus p=.1'):
        sample(model, ds, NucleusSampler(0.1, TemperatureSampler(1.)), 4, 32, 128, 'It is')


if __name__ == '__main__':
    main()
