__call__ -> forward

This commit is contained in:
Varuna Jayasiri
2021-08-19 15:45:59 +05:30
parent f1fe7087f1
commit eaa248c9e6
4 changed files with 8 additions and 8 deletions

View File

@ -107,7 +107,7 @@ class ParityPonderGRU(Module):
# An option to set during inference so that computation is actually halted at inference time
self.is_halt = False
def __call__(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
* `x` is the input of shape `[batch_size, n_elems]`
@ -194,7 +194,7 @@ class ReconstructionLoss(Module):
super().__init__()
self.loss_func = loss_func
def __call__(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
"""
* `p` is $p_1 \dots p_N$ in a tensor of shape `[N, batch_size]`
* `y_hat` is $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape `[N, batch_size, ...]`
@ -254,7 +254,7 @@ class RegularizationLoss(Module):
# KL-divergence loss
self.kl_div = nn.KLDivLoss(reduction='batchmean')
def __call__(self, p: torch.Tensor):
def forward(self, p: torch.Tensor):
"""
* `p` is $p_1 \dots p_N$ in a tensor of shape `[N, batch_size]`
"""

View File

@ -56,7 +56,7 @@ class Squash(Module):
super().__init__()
self.epsilon = epsilon
def __call__(self, s: torch.Tensor):
def forward(self, s: torch.Tensor):
"""
The shape of `s` is `[batch_size, n_capsules, n_features]`
"""
@ -100,7 +100,7 @@ class Router(Module):
# lower layer to each capsule in this layer
self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)
def __call__(self, u: torch.Tensor):
def forward(self, u: torch.Tensor):
"""
The shape of `u` is `[batch_size, n_capsules, n_features]`.
These are the capsules from the lower layer.
@ -162,7 +162,7 @@ class MarginLoss(Module):
self.lambda_ = lambda_
self.n_labels = n_labels
def __call__(self, v: torch.Tensor, labels: torch.Tensor):
def forward(self, v: torch.Tensor, labels: torch.Tensor):
"""
`v`, $\mathbf{v}_j$ are the squashed output capsules.
This has shape `[batch_size, n_labels, n_features]`; that is, there is a capsule for each label.

View File

@ -104,7 +104,7 @@ class CIFAR10VGGModel(Module):
# Final logits layer
self.fc = nn.Linear(in_channels, 10)
def __call__(self, x):
def forward(self, x):
# The VGG layers
x = self.layers(x)
# Reshape for classification layer

View File

@ -34,7 +34,7 @@ class CrossEntropyLoss(Module):
super().__init__()
self.loss = nn.CrossEntropyLoss()
def __call__(self, outputs, targets):
def forward(self, outputs, targets):
return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))