1import torch
2from torch import nn
3
4from labml_helpers.module import Module
class Swish(Module):
    """Swish activation: multiplies the input by its own sigmoid.

    Computes ``x * sigmoid(x)`` element-wise. The module holds no
    learnable parameters; the sigmoid is kept as an ``nn.Sigmoid``
    submodule under the attribute ``sigmoid``.
    """

    def __init__(self):
        """Create the ``nn.Sigmoid`` submodule used as the gate."""
        super().__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled element-wise by ``sigmoid(x)``.

        Args:
            x: Input tensor.

        Returns:
            The product ``x * sigmoid(x)``.
        """
        # Gate values come from the sigmoid submodule created in __init__.
        gate = self.sigmoid(x)
        return x * gate