mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 10:48:49 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			25 lines
		
	
	
		
			543 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			25 lines
		
	
	
		
			543 B
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| ---
 | |
| title: Greedy Sampling
 | |
| summary: A PyTorch implementation of greedy sampling from language models.
 | |
| ---
 | |
| 
 | |
| # Greedy Sampling
 | |
| 
 | |
| Here we sample the most likely token from the distribution of logits.
 | |
| 
 | |
| Here's an [experiment](experiment.html) that uses these sampling techniques.
 | |
| """
 | |
| 
 | |
| import torch
 | |
| 
 | |
| from labml_nn.sampling import Sampler
 | |
| 
 | |
| 
 | |
class GreedySampler(Sampler):
    """
    ## Greedy Sampler

    Selects the entry with the largest logit instead of drawing at random.
    """

    def __call__(self, logits: torch.Tensor):
        """
        Pick the index of the largest logit along the last dimension.

        * `logits` is the tensor of logits; presumably its last dimension
          spans the vocabulary (confirm against the caller)

        Returns a tensor of indices, reduced over the last dimension of `logits`.
        """
        # The greedy choice is the arg-max over the final axis
        token_ids = torch.argmax(logits, dim=-1)
        return token_ids
 | 
