mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 18:58:43 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			60 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			60 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import matplotlib.pyplot as plt
 | |
| import numpy as np
 | |
| import torch
 | |
| import torch.nn as nn
 | |
| 
 | |
| from labml.helpers.pytorch.module import Module
 | |
| 
 | |
| 
 | |
class LabelSmoothingLoss(Module):
    """KL-divergence loss against a label-smoothed target distribution.

    For each row, the target distribution puts ``1 - smoothing`` on the true
    class and spreads ``smoothing`` evenly over the remaining ``size - 2``
    classes (every class except the true class and the padding class). The
    padding class always gets zero mass. Rows whose target *is* the padding
    index get an all-zero distribution and so contribute nothing to the loss.

    Args:
        size: number of classes; ``x`` must have exactly this many columns.
            Must be greater than 2, otherwise the per-class smoothing mass
            ``smoothing / (size - 2)`` is undefined.
        padding_idx: class index reserved for padding.
        smoothing: probability mass moved from the true class to the other
            non-padding classes.
    """

    def __init__(self, size: int, padding_idx: int, smoothing: float = 0.0):
        super().__init__()
        # Summed (not averaged) KL divergence over all entries.
        self.loss = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        # Target distribution built by the most recent call; kept for
        # inspection/plotting. It never carries autograd history.
        self.true_dist = None

    def __call__(self, x: torch.Tensor, target: torch.Tensor):
        """Return the summed KL divergence between ``x`` and the smoothed targets.

        Args:
            x: log-probabilities of shape ``(batch, size)``.
            target: true class indices of shape ``(batch,)``.

        Returns:
            A scalar tensor holding the summed KL divergence.
        """
        assert x.size(1) == self.size

        # The target is a constant. Build it outside autograd so it neither
        # records ops in the graph nor keeps that graph alive through
        # ``self.true_dist``.
        with torch.no_grad():
            true_dist = torch.full_like(x, self.smoothing / (self.size - 2))
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
            # The padding class never receives probability mass.
            true_dist[:, self.padding_idx] = 0
            # Rows whose target is padding are ignored entirely.
            true_dist.masked_fill_((target == self.padding_idx).unsqueeze(1), 0.0)
        self.true_dist = true_dist

        # Entries with zero target mass contribute 0 to the KL term and get a
        # zero gradient. Replace their log-probs with 0 so that a -inf there
        # (e.g. a masked padding logit, or log(0)) cannot produce 0 * -inf = NaN.
        x = x.masked_fill(true_dist == 0, 0.0)
        return self.loss(x, true_dist)
 | |
| 
 | |
| 
 | |
def _test_label_smoothing():
    """Visual sanity check for ``LabelSmoothingLoss``; shows two matplotlib figures.

    1. A heat-map of the smoothed target distributions built for three
       identical predictions with targets 2, 1 and 0. Index 0 is the padding
       index here, so the last row is all zeros.
    2. The loss for a single example whose true class is 1, plotted against
       the un-normalised weight ``w`` (1..99) the prediction puts on class 1.
       Classes 2-4 keep weight 1 each and the padding class gets 0.
    """
    # Part 1: what the smoothed targets look like (smoothing = 0.4).
    criterion = LabelSmoothingLoss(5, 0, 0.4)
    probs = torch.tensor([[0, 0.2, 0.7, 0.1, 0]] * 3, dtype=torch.float)
    targets = torch.tensor([2, 1, 0], dtype=torch.long)
    criterion(probs.log(), targets)  # loss value unused; the call fills true_dist
    plt.imshow(criterion.true_dist)
    plt.show()

    # Part 2: loss as the prediction grows more confident (smoothing = 0.1).
    criterion = LabelSmoothingLoss(5, 0, 0.1)

    def loss_for(weight):
        # Normalise the weights [0, weight, 1, 1, 1] into a probability row.
        norm = weight + 3
        row = [0, weight / norm, 1 / norm, 1 / norm, 1 / norm]
        log_probs = torch.tensor([row], dtype=torch.float).log()
        return criterion(log_probs, torch.tensor([1], dtype=torch.long)).item()

    losses = [loss_for(w) for w in range(1, 100)]
    plt.plot(np.arange(1, 100), losses)
    plt.show()
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Runs the label-smoothing demo only when this file is executed directly.
    _test_label_smoothing()
 | 
