__call__ -> forward
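The rename matters because `torch.nn.Module` already defines `__call__`: calling a module instance goes through `nn.Module.__call__`, which dispatches to `forward` and also runs any registered hooks. Overriding `forward` (rather than `__call__`) keeps that machinery intact. A minimal sketch with a toy module (the `Scale` class below is illustrative, not from this repository):

import torch
from torch import nn


class Scale(nn.Module):
    """Toy module: define `forward` (not `__call__`) so `nn.Module.__call__` can dispatch to it."""

    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor


m = Scale(2.)
# A forward hook only fires when the module is invoked as `m(x)`, i.e. through
# `nn.Module.__call__`; this is why `forward` is the method to override.
m.register_forward_hook(lambda module, inp, out: print('hook saw output of shape', out.shape))
y = m(torch.ones(3))  # prints the hook message; y == tensor([2., 2., 2.])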
@@ -107,7 +107,7 @@ class ParityPonderGRU(Module):
         # An option to set during inference so that computation is actually halted at inference time
         self.is_halt = False

-    def __call__(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         * `x` is the input of shape `[batch_size, n_elems]`
@@ -194,7 +194,7 @@ class ReconstructionLoss(Module):
         super().__init__()
         self.loss_func = loss_func

-    def __call__(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
+    def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
         """
         * `p` is $p_1 \dots p_N$ in a tensor of shape `[N, batch_size]`
         * `y_hat` is $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape `[N, batch_size, ...]`
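For context, the reconstruction loss in the PonderNet paper is the expected per-step loss under the halting distribution, $L_{Rec} = \sum_n p_n \, L(y, \hat{y}_n)$. A minimal sketch of that idea, assuming `loss_func` returns a per-sample loss of shape `[batch_size]`; the class name and details below are illustrative, not the repository's code:

import torch
from torch import nn


class WeightedStepLoss(nn.Module):
    """Sketch of a PonderNet-style reconstruction loss: the expected loss over halting steps."""

    def __init__(self, loss_func: nn.Module):
        super().__init__()
        # e.g. nn.BCEWithLogitsLoss(reduction='none'), so it returns one loss value per sample
        self.loss_func = loss_func

    def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
        # `p` is [N, batch_size], `y_hat` is [N, batch_size, ...], `y` is [batch_size, ...]
        total = p.new_tensor(0.)
        for n in range(p.shape[0]):
            # weight the step-n loss by the probability of halting at step n
            total = total + (p[n] * self.loss_func(y_hat[n], y)).mean()
        return total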
@@ -254,7 +254,7 @@ class RegularizationLoss(Module):
         # KL-divergence loss
         self.kl_div = nn.KLDivLoss(reduction='batchmean')

-    def __call__(self, p: torch.Tensor):
+    def forward(self, p: torch.Tensor):
         """
         * `p` is $p_1 \dots p_N$ in a tensor of shape `[N, batch_size]`
         """
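For context, the PonderNet regularization term is a KL divergence between the halting distribution and a geometric prior, which is what `nn.KLDivLoss(reduction='batchmean')` is set up for (it expects log-probabilities as input and probabilities as target). A minimal sketch under those assumptions; the class name and construction details are illustrative, not the repository's code:

import torch
from torch import nn


class GeometricKLRegularizer(nn.Module):
    """Sketch: KL(p || Geometric(lambda_p)) over at most `max_steps` halting steps."""

    def __init__(self, lambda_p: float, max_steps: int):
        super().__init__()
        # Prior: P(halt at step n) = lambda_p * (1 - lambda_p)^(n-1)
        not_halted = 1.
        p_g = torch.zeros(max_steps)
        for n in range(max_steps):
            p_g[n] = not_halted * lambda_p
            not_halted = not_halted * (1. - lambda_p)
        self.register_buffer('p_g', p_g)
        # KL-divergence loss
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, p: torch.Tensor):
        # `p` is [N, batch_size] with p > 0; KLDivLoss wants [batch_size, N]
        # log-probabilities (input) against probabilities (target)
        p = p.transpose(0, 1)
        p_g = self.p_g[:p.shape[1]].expand_as(p)
        return self.kl_div(p.log(), p_g)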
@@ -56,7 +56,7 @@ class Squash(Module):
         super().__init__()
         self.epsilon = epsilon

-    def __call__(self, s: torch.Tensor):
+    def forward(self, s: torch.Tensor):
         """
         The shape of `s` is `[batch_size, n_capsules, n_features]`
         """
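For context, `Squash` is the capsule non-linearity from Sabour et al. (2017), $v = \frac{\|s\|^2}{1 + \|s\|^2} \frac{s}{\|s\|}$, with `epsilon` guarding against division by zero. A minimal sketch of that formula (an illustration, not the repository's exact code):

import torch
from torch import nn


class SquashSketch(nn.Module):
    def __init__(self, epsilon: float = 1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, s: torch.Tensor):
        # `s` has shape [batch_size, n_capsules, n_features]
        s2 = (s ** 2).sum(dim=-1, keepdim=True)  # squared norms, [batch_size, n_capsules, 1]
        # shrink each capsule vector so its norm becomes |s|^2 / (1 + |s|^2) < 1
        return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))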
@@ -100,7 +100,7 @@ class Router(Module):
         # lower layer to each capsule in this layer
         self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)

-    def __call__(self, u: torch.Tensor):
+    def forward(self, u: torch.Tensor):
         """
         The shape of `u` is `[batch_size, n_capsules, n_features]`.
         These are the capsules from the lower layer.
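For context, `Router` implements routing between capsule layers; `self.weight` holds a separate `in_d -> out_d` map for every (input capsule, output capsule) pair. A minimal sketch of routing by agreement from Sabour et al. (2017), under assumed shapes and with an inlined squash; this is an illustration, not the repository's code:

import torch
import torch.nn.functional as F
from torch import nn


class RouterSketch(nn.Module):
    def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int = 3):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)
        self.iterations = iterations

    def forward(self, u: torch.Tensor):
        # `u` is [batch_size, in_caps, in_d]: capsules from the lower layer.
        # Predictions of each input capsule for each output capsule: [batch, in_caps, out_caps, out_d]
        u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)
        # Routing logits, refined by agreement over a few iterations
        b = u.new_zeros(u.shape[0], self.weight.shape[0], self.weight.shape[1])
        for _ in range(self.iterations):
            c = F.softmax(b, dim=2)                        # coupling coefficients per input capsule
            s = torch.einsum('bij,bijm->bjm', c, u_hat)    # weighted sum over input capsules
            # squash non-linearity (see the sketch after the Squash hunk above)
            s2 = (s ** 2).sum(dim=-1, keepdim=True)
            v = (s2 / (1 + s2)) * (s / torch.sqrt(s2 + 1e-8))
            b = b + torch.einsum('bjm,bijm->bij', v, u_hat)  # agreement update
        return v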
@@ -162,7 +162,7 @@ class MarginLoss(Module):
         self.lambda_ = lambda_
         self.n_labels = n_labels

-    def __call__(self, v: torch.Tensor, labels: torch.Tensor):
+    def forward(self, v: torch.Tensor, labels: torch.Tensor):
         """
         `v`, $\mathbf{v}_j$ are the squashed output capsules.
         This has shape `[batch_size, n_labels, n_features]`; that is, there is a capsule for each label.
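For context, the margin loss from Sabour et al. (2017) is $L_k = T_k \max(0, m^+ - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\| - m^-)^2$, with the paper's defaults $m^+ = 0.9$, $m^- = 0.1$, $\lambda = 0.5$. A minimal sketch of that formula; attribute names mirror the hunk above, but the body is an assumption, not the repository's code:

import torch
import torch.nn.functional as F
from torch import nn


class MarginLossSketch(nn.Module):
    def __init__(self, *, n_labels: int, lambda_: float = 0.5,
                 m_positive: float = 0.9, m_negative: float = 0.1):
        super().__init__()
        self.lambda_ = lambda_
        self.n_labels = n_labels
        self.m_positive = m_positive
        self.m_negative = m_negative

    def forward(self, v: torch.Tensor, labels: torch.Tensor):
        # `v` is [batch_size, n_labels, n_features]; `labels` is [batch_size] of class ids
        v_norm = torch.sqrt((v ** 2).sum(dim=-1))                  # capsule lengths, [batch_size, n_labels]
        t = F.one_hot(labels, num_classes=self.n_labels).float()   # T_k
        loss = (t * F.relu(self.m_positive - v_norm) ** 2 +
                self.lambda_ * (1 - t) * F.relu(v_norm - self.m_negative) ** 2)
        return loss.sum(dim=-1).mean()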
@@ -104,7 +104,7 @@ class CIFAR10VGGModel(Module):
         # Final logits layer
         self.fc = nn.Linear(in_channels, 10)

-    def __call__(self, x):
+    def forward(self, x):
         # The VGG layers
         x = self.layers(x)
         # Reshape for classification layer
@@ -34,7 +34,7 @@ class CrossEntropyLoss(Module):
         super().__init__()
         self.loss = nn.CrossEntropyLoss()

-    def __call__(self, outputs, targets):
+    def forward(self, outputs, targets):
         return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
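The flattening in that last hunk is the usual way of applying `nn.CrossEntropyLoss` to sequence outputs: every token position is treated as one classification example. A short usage sketch under assumed shapes (the sizes below are illustrative):

import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()
# Assumed shapes: 2 sequences of 5 tokens each, vocabulary of 10 tokens
outputs = torch.randn(5, 2, 10)          # [seq_len, batch_size, n_tokens] logits
targets = torch.randint(0, 10, (5, 2))   # [seq_len, batch_size] token ids
# Flatten so the loss sees [seq_len * batch_size, n_tokens] against [seq_len * batch_size]
loss = loss_fn(outputs.view(-1, outputs.shape[-1]), targets.view(-1))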