mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-01 03:43:09 +08:00
17 lines
546 B
Python
17 lines
546 B
Python
import torch.nn.functional as F
|
|
from torch import nn
|
|
|
|
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy loss with uniform label smoothing.

    The loss blends two terms computed from the log-softmax of ``pred``:

    * the negative log-likelihood of the target class, weighted by
      ``1 - epsilon``;
    * the negative log-probability averaged over all classes (i.e. the
      cross-entropy against a uniform distribution), weighted by ``epsilon``.

    Both terms are reduced over the batch with the same ``reduction`` mode,
    so ``'sum'`` and ``'none'`` produce consistently scaled results.
    """

    def __init__(self, epsilon=0.5, reduction='mean'):
        """Store the smoothing configuration.

        Args:
            epsilon: Weight given to the uniform (smoothing) term; the
                target-class term receives ``1 - epsilon``.
            reduction: Batch reduction mode: ``'mean'``, ``'sum'`` or
                ``'none'``. It is forwarded unchanged to ``F.nll_loss``.
        """
        super().__init__()
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, pred, target):
        """Compute the label-smoothed loss.

        Args:
            pred: Unnormalised scores (logits) with the classes laid out
                along the last dimension.
            target: Target class indices, in the form ``F.nll_loss`` accepts
                for ``pred``.

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'``, or one loss value per
            sample for ``'none'``.
        """
        n = pred.size()[-1]
        log_pred = F.log_softmax(pred, dim=-1)

        # Per-sample sum of -log p over every class; dividing by ``n`` below
        # turns it into the cross-entropy against a uniform distribution.
        smooth = -log_pred.sum(dim=-1)

        # Reduce the smoothing term exactly like the NLL term. Averaging it
        # unconditionally would mix a batch mean into 'sum' / 'none' results.
        if self.reduction == 'mean':
            smooth = smooth.mean()
        elif self.reduction == 'sum':
            smooth = smooth.sum()
        # Any other value leaves ``smooth`` per-sample; unsupported modes are
        # rejected by ``F.nll_loss`` on the next line.

        nll = F.nll_loss(log_pred, target, reduction=self.reduction)
        return (1 - self.epsilon) * nll + self.epsilon * (smooth / n)
|