1import torch
2from torch import nn
3
4from labml_helpers.module import Module
class Swish(Module):
    """
    Swish activation: ``x * sigmoid(beta * x)``.

    With the default ``beta=1.0`` this is ``x * sigmoid(x)``, the same
    function as ``torch.nn.SiLU``. Other values of ``beta`` give the
    Swish-β variant: large ``beta`` makes the gate sharper (closer to
    ReLU), while ``beta=0`` reduces it to the linear function ``x / 2``.
    """

    def __init__(self, beta: float = 1.0):
        """
        :param beta: fixed (non-trainable) scale applied to the input of the
            sigmoid gate. Stored as a plain float, not a parameter or buffer,
            so it does not appear in ``state_dict``.
        """
        super().__init__()
        self.beta = beta
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the activation element-wise.

        :param x: input tensor of any shape
        :return: tensor of the same shape as ``x``
        """
        if self.beta == 1.0:
            # Default path: skip the scaling so the result is exactly
            # x * sigmoid(x), identical to the original implementation.
            return x * self.sigmoid(x)
        return x * self.sigmoid(self.beta * x)