1import torch.nn.functional as F
2from torch import nn
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy loss with uniform label smoothing.

    The per-sample loss is a convex combination of the ordinary negative
    log-likelihood of the target class and the average negative
    log-probability over all classes::

        loss_i = (1 - epsilon) * -log p_i[target_i]
                 + epsilon * (-sum_c log p_i[c]) / n_classes

    Both terms are reduced with the same ``reduction`` mode, so ``'none'``
    yields true per-sample losses and ``'sum'`` is the sum of those.

    Args:
        epsilon: Weight given to the uniform (smoothing) term.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``; forwarded to
            ``F.nll_loss`` and applied identically to the smoothing term.
    """

    def __init__(self, epsilon=0.5, reduction='mean'):
        super().__init__()
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, pred, target):
        """Return the label-smoothed loss for ``pred`` against ``target``.

        Args:
            pred: Unnormalised class scores. ``log_softmax`` is taken over
                the last dimension, while ``F.nll_loss`` treats dimension 1
                as the class dimension; the two agree for 2-D
                ``(batch, classes)`` input.
            target: Class indices, passed straight to ``F.nll_loss``.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, or one loss value per
            sample for ``'none'``.
        """
        n_classes = pred.size(-1)
        log_pred = F.log_softmax(pred, dim=-1)

        # Per-sample smoothing term: negative log-probability summed over
        # all classes (divided by n_classes below to make it an average).
        smooth = -log_pred.sum(dim=-1)

        # Reduce the smoothing term the same way nll_loss reduces its
        # output; otherwise 'sum' and 'none' would mix quantities that are
        # on different scales / shapes.
        if self.reduction == 'mean':
            smooth = smooth.mean()
        elif self.reduction == 'sum':
            smooth = smooth.sum()
        # 'none' keeps the per-sample values; any invalid reduction string
        # is rejected by F.nll_loss below.

        nll = F.nll_loss(log_pred, target, reduction=self.reduction)
        return (1 - self.epsilon) * nll + self.epsilon * (smooth / n_classes)