__call__ -> forward

This commit is contained in:
Varuna Jayasiri
2021-08-19 15:53:58 +05:30
parent 0a4b5b6822
commit a8e3695da3
4 changed files with 15 additions and 15 deletions

View File

@ -211,7 +211,7 @@ class EncoderRNN(Module):
# Head to get $\hat{\sigma}$
self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
def __call__(self, inputs: torch.Tensor, state=None):
def forward(self, inputs: torch.Tensor, state=None):
# The hidden state of the bidirectional LSTM is the concatenation of the
# output of the last token in the forward direction and
# first token in the reverse direction, which is what we want.
@ -269,7 +269,7 @@ class DecoderRNN(Module):
self.n_distributions = n_distributions
self.dec_hidden_size = dec_hidden_size
def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
# Calculate the initial state
if state is None:
# $[h_0; c_0] = \tanh(W_{z}z + b_z)$
@ -314,7 +314,7 @@ class ReconstructionLoss(Module):
## Reconstruction Loss
"""
def __call__(self, mask: torch.Tensor, target: torch.Tensor,
def forward(self, mask: torch.Tensor, target: torch.Tensor,
dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
# Get $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$
pi, mix = dist.get_distribution()
@ -355,7 +355,7 @@ class KLDivLoss(Module):
This calculates the KL divergence between a given normal distribution and $\mathcal{N}(0, 1)$
"""
def __call__(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
# $$L_{KL} = - \frac{1}{2 N_z} \bigg( 1 + \hat{\sigma} - \mu^2 - \exp(\hat{\sigma}) \bigg)$$
return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))

View File

@ -145,7 +145,7 @@ class DPFP(Module):
self.relu = nn.ReLU()
self.eps = eps
def __call__(self, k: torch.Tensor):
def forward(self, k: torch.Tensor):
# Get $\color{lightgreen}{\phi(k)}$
k = self.dpfp(k)
# Normalize by $\sum^{d_{dot}}_{j=1} \color{lightgreen}{\phi(k)_j}$
@ -228,7 +228,7 @@ class FastWeightsAttention(Module):
# Dropout
self.dropout = nn.Dropout(dropout_prob)
def __call__(self, x: torch.Tensor):
def forward(self, x: torch.Tensor):
# Get the number of steps $L$
seq_len = x.shape[0]
# $\color{lightgreen}{\phi'(q^{(i)})}$ for all steps and heads
@ -291,7 +291,7 @@ class FastWeightsAttentionTransformerLayer(Module):
self.norm_self_attn = nn.LayerNorm([d_model])
self.norm_ff = nn.LayerNorm([d_model])
def __call__(self, x: torch.Tensor):
def forward(self, x: torch.Tensor):
# Calculate fast weights self attention
attn = self.attn(x)
# Add the self attention results
@ -319,7 +319,7 @@ class FastWeightsAttentionTransformer(Module):
# Final normalization layer
self.norm = nn.LayerNorm([layer.size])
def __call__(self, x: torch.Tensor):
def forward(self, x: torch.Tensor):
for i, layer in enumerate(self.layers):
# Get layer output
x = layer(x)

View File

@ -43,7 +43,7 @@ class FastWeightsAttention(Module):
# Dropout
self.dropout = nn.Dropout(dropout_prob)
def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
query = self.phi(self.query(x))
key = self.phi(self.key(x))
value = self.value(x)
@ -84,7 +84,7 @@ class FastWeightsAttentionTransformerLayer(Module):
self.norm_self_attn = nn.LayerNorm([d_model])
self.norm_ff = nn.LayerNorm([d_model])
def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
attn, weights = self.attn(x, weights)
# Add the self attention results
x = x + self.dropout(attn)
@ -108,7 +108,7 @@ class FastWeightsAttentionTransformer(Module):
# Final normalization layer
self.norm = nn.LayerNorm([layer.size])
def __call__(self, x_seq: torch.Tensor):
def forward(self, x_seq: torch.Tensor):
# Split the input to a list along the sequence axis
x_seq = torch.unbind(x_seq, dim=0)
# List to store the outputs

View File

@ -75,7 +75,7 @@ class PatchEmbeddings(Module):
# transformation on each patch.
self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
def __call__(self, x: torch.Tensor):
def forward(self, x: torch.Tensor):
"""
* `x` is the input image of shape `[batch_size, channels, height, width]`
"""
@ -109,7 +109,7 @@ class LearnedPositionalEmbeddings(Module):
# Positional embeddings for each location
self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
def __call__(self, x: torch.Tensor):
def forward(self, x: torch.Tensor):
"""
* `x` is the patch embeddings of shape `[patches, batch_size, d_model]`
"""
@ -141,7 +141,7 @@ class ClassificationHead(Module):
# Second layer
self.linear2 = nn.Linear(n_hidden, n_classes)
def __call__(self, x: torch.Tensor):
def forward(self, x: torch.Tensor):
"""
* `x` is the transformer encoding for `[CLS]` token
"""
@ -187,7 +187,7 @@ class VisionTransformer(Module):
# Final normalization layer
self.ln = nn.LayerNorm([transformer_layer.size])
def __call__(self, x: torch.Tensor):
def forward(self, x: torch.Tensor):
"""
* `x` is the input image of shape `[batch_size, channels, height, width]`
"""