mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-29 17:57:14 +08:00
__call__ -> forward
This commit is contained in:
@ -211,7 +211,7 @@ class EncoderRNN(Module):
|
||||
# Head to get $\hat{\sigma}$
|
||||
self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
|
||||
|
||||
def __call__(self, inputs: torch.Tensor, state=None):
|
||||
def forward(self, inputs: torch.Tensor, state=None):
|
||||
# The hidden state of the bidirectional LSTM is the concatenation of the
|
||||
# output of the last token in the forward direction and
|
||||
# first token in the reverse direction, which is what we want.
|
||||
@ -269,7 +269,7 @@ class DecoderRNN(Module):
|
||||
self.n_distributions = n_distributions
|
||||
self.dec_hidden_size = dec_hidden_size
|
||||
|
||||
def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
|
||||
def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
|
||||
# Calculate the initial state
|
||||
if state is None:
|
||||
# $[h_0; c_0] = \tanh(W_{z}z + b_z)$
|
||||
@ -314,7 +314,7 @@ class ReconstructionLoss(Module):
|
||||
## Reconstruction Loss
|
||||
"""
|
||||
|
||||
def __call__(self, mask: torch.Tensor, target: torch.Tensor,
|
||||
def forward(self, mask: torch.Tensor, target: torch.Tensor,
|
||||
dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
|
||||
# Get $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$
|
||||
pi, mix = dist.get_distribution()
|
||||
@ -355,7 +355,7 @@ class KLDivLoss(Module):
|
||||
This calculates the KL divergence between a given normal distribution and $\mathcal{N}(0, 1)$
|
||||
"""
|
||||
|
||||
def __call__(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
|
||||
def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
|
||||
# $$L_{KL} = - \frac{1}{2 N_z} \bigg( 1 + \hat{\sigma} - \mu^2 - \exp(\hat{\sigma}) \bigg)$$
|
||||
return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
|
||||
|
||||
|
||||
@ -145,7 +145,7 @@ class DPFP(Module):
|
||||
self.relu = nn.ReLU()
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, k: torch.Tensor):
|
||||
def forward(self, k: torch.Tensor):
|
||||
# Get $\color{lightgreen}{\phi(k)}$
|
||||
k = self.dpfp(k)
|
||||
# Normalize by $\sum^{d_{dot}}_{j=1} \color{lightgreen}{\phi(k)_j}$
|
||||
@ -228,7 +228,7 @@ class FastWeightsAttention(Module):
|
||||
# Dropout
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
|
||||
def __call__(self, x: torch.Tensor):
|
||||
def forward(self, x: torch.Tensor):
|
||||
# Get the number of steps $L$
|
||||
seq_len = x.shape[0]
|
||||
# $\color{lightgreen}{\phi'(q^{(i)})}$ for all steps and heads
|
||||
@ -291,7 +291,7 @@ class FastWeightsAttentionTransformerLayer(Module):
|
||||
self.norm_self_attn = nn.LayerNorm([d_model])
|
||||
self.norm_ff = nn.LayerNorm([d_model])
|
||||
|
||||
def __call__(self, x: torch.Tensor):
|
||||
def forward(self, x: torch.Tensor):
|
||||
# Calculate fast weights self attention
|
||||
attn = self.attn(x)
|
||||
# Add the self attention results
|
||||
@ -319,7 +319,7 @@ class FastWeightsAttentionTransformer(Module):
|
||||
# Final normalization layer
|
||||
self.norm = nn.LayerNorm([layer.size])
|
||||
|
||||
def __call__(self, x: torch.Tensor):
|
||||
def forward(self, x: torch.Tensor):
|
||||
for i, layer in enumerate(self.layers):
|
||||
# Get layer output
|
||||
x = layer(x)
|
||||
|
||||
@ -43,7 +43,7 @@ class FastWeightsAttention(Module):
|
||||
# Dropout
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
|
||||
def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
|
||||
def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
|
||||
query = self.phi(self.query(x))
|
||||
key = self.phi(self.key(x))
|
||||
value = self.value(x)
|
||||
@ -84,7 +84,7 @@ class FastWeightsAttentionTransformerLayer(Module):
|
||||
self.norm_self_attn = nn.LayerNorm([d_model])
|
||||
self.norm_ff = nn.LayerNorm([d_model])
|
||||
|
||||
def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
|
||||
def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
|
||||
attn, weights = self.attn(x, weights)
|
||||
# Add the self attention results
|
||||
x = x + self.dropout(attn)
|
||||
@ -108,7 +108,7 @@ class FastWeightsAttentionTransformer(Module):
|
||||
# Final normalization layer
|
||||
self.norm = nn.LayerNorm([layer.size])
|
||||
|
||||
def __call__(self, x_seq: torch.Tensor):
|
||||
def forward(self, x_seq: torch.Tensor):
|
||||
# Split the input to a list along the sequence axis
|
||||
x_seq = torch.unbind(x_seq, dim=0)
|
||||
# List to store the outputs
|
||||
|
||||
@ -75,7 +75,7 @@ class PatchEmbeddings(Module):
|
||||
# transformation on each patch.
|
||||
self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
|
||||
|
||||
def __call__(self, x: torch.Tensor):
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
* `x` is the input image of shape `[batch_size, channels, height, width]`
|
||||
"""
|
||||
@ -109,7 +109,7 @@ class LearnedPositionalEmbeddings(Module):
|
||||
# Positional embeddings for each location
|
||||
self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
|
||||
|
||||
def __call__(self, x: torch.Tensor):
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
* `x` is the patch embeddings of shape `[patches, batch_size, d_model]`
|
||||
"""
|
||||
@ -141,7 +141,7 @@ class ClassificationHead(Module):
|
||||
# Second layer
|
||||
self.linear2 = nn.Linear(n_hidden, n_classes)
|
||||
|
||||
def __call__(self, x: torch.Tensor):
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
* `x` is the transformer encoding for `[CLS]` token
|
||||
"""
|
||||
@ -187,7 +187,7 @@ class VisionTransformer(Module):
|
||||
# Final normalization layer
|
||||
self.ln = nn.LayerNorm([transformer_layer.size])
|
||||
|
||||
def __call__(self, x: torch.Tensor):
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
* `x` is the input image of shape `[batch_size, channels, height, width]`
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user