| """
 | |
| ---
 | |
| title: Trying out Sampling Techniques for Language Models
 | |
| summary: >
 | |
|  We try out different sampling techniques for language models on HuggingFace's GPT2 model.
 | |
| ---
 | |
| 
 | |
| # Trying out Sampling Techniques for Language Models
 | |
| 
 | |
| * [Greedy Sampling](greedy.html)
 | |
| * [Temperature Sampling](temperature.html)
 | |
| * [Top-k Sampling](top_k.html)
 | |
| * [Nucleus Sampling](nucleus.html)
 | |
| 
 | |
| This experiment uses the above sampling techniques, on HuggingFace's GPT2 model.
 | |
| """
 | |

import torch

from labml import monit, logger, lab

from labml.logger import Text

from labml_nn.sampling import Sampler
from labml_nn.sampling.greedy import GreedySampler
from labml_nn.sampling.nucleus import NucleusSampler
from labml_nn.sampling.temperature import TemperatureSampler
from labml_nn.sampling.top_k import TopKSampler
from transformers import GPT2Tokenizer, GPT2LMHeadModel


@torch.no_grad()
def sample(model: GPT2LMHeadModel, tokenizer: GPT2Tokenizer, sampler: Sampler,
           n_samples: int, n_tokens: int, seq_len: int, prompt: str):
    """
    ## Sample from model

    :param model: is the model to sample from
    :param tokenizer: is the tokenizer to use
    :param sampler: is the sampler to use
    :param n_samples: is the number of samples to generate
    :param n_tokens: is the number of tokens to generate
    :param seq_len: is the maximum sequence length for the model
    :param prompt: is the starting prompt
    """
    # Tokenize the `prompt` and make `n_samples` copies of it
    data = torch.tile(torch.tensor(tokenizer.encode(prompt))[None, :], (n_samples, 1))
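    # `data` now has shape `[n_samples, prompt_length]`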

    # Collect output for printing
    logs = [[(prompt, Text.meta)] for _ in range(n_samples)]
    # Sample `n_tokens`
    for i in monit.iterate('Sample', n_tokens):
        # Truncate the data to the maximum sequence length
        data = data[:, -seq_len:]
        # Get the model output. The `logits` have shape `[batch_size, seq_len, vocab_size]`
        logits = model(data)[0]
        # Get the `logits` of the last token
        logits = logits[:, -1]
        # Sample from the `logits`
        res = sampler(logits)
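        # `res` has shape `[n_samples]`: one sampled token id per copy of the prompt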
        # Add the sampled token to the data
        data = torch.cat([data, res[:, None]], dim=1)
        # Decode and add the sampled token for logging
        for j in range(n_samples):
            logs[j] += [(tokenizer.decode(res[j]), Text.value)]

    # Print the sampled outputs
    for j in range(n_samples):
        logger.log(logs[j])


def main():
    """
    ### Try different sampling techniques
    """

    # Load the model and tokenizer
    with monit.section('Load tokenizer/model'):
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')
        model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')
    # Set the model to eval mode
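    # (this disables dropout, so the model outputs are deterministic for a given input)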
    model.eval()

    # Prompt to use for sampling
    prompt = 'I saw an interesting dream last night. '

    # [Greedy Sampling](greedy.html)
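    # Greedy sampling always picks the most probable (argmax) token, so all samples are identical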
    with monit.section('greedy'):
        sample(model, tokenizer, GreedySampler(), 4, 32, 128, prompt)

    # [Temperature Sampling](temperature.html)
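    # Temperature sampling divides the logits by a temperature before sampling from the softmax:
    # low temperatures approach greedy decoding, high temperatures approach uniform sampling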
    with monit.section('temperature=1.'):
        sample(model, tokenizer, TemperatureSampler(1.), 4, 32, 128, prompt)
    with monit.section('temperature=.1'):
        sample(model, tokenizer, TemperatureSampler(.1), 4, 32, 128, prompt)
    with monit.section('temperature=10.'):
        sample(model, tokenizer, TemperatureSampler(10.), 4, 32, 128, prompt)

    # [Top-k Sampling](top_k.html)
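    # Top-k sampling keeps only the `k` most probable tokens and samples among them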
    with monit.section('top_k=2'):
        sample(model, tokenizer, TopKSampler(2, TemperatureSampler(1.)), 4, 32, 128, prompt)

    # [Nucleus Sampling](nucleus.html)
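    # Nucleus sampling samples from the smallest set of tokens whose cumulative probability exceeds `p`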
    with monit.section('nucleus p=.95'):
        sample(model, tokenizer, NucleusSampler(0.95, TemperatureSampler(1.)), 4, 32, 128, prompt)
    with monit.section('nucleus p=.1'):
        sample(model, tokenizer, NucleusSampler(0.1, TemperatureSampler(1.)), 4, 32, 128, prompt)

#
if __name__ == '__main__':
    main()