__call__ -> forward
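Background for the rename: in PyTorch, `nn.Module.__call__` is the entry point that dispatches to `forward` and also runs any registered hooks, so a module that defines `__call__` directly bypasses hook handling. A minimal sketch of the behavior this rename preserves (illustrative, not code from this repo):

import torch
import torch.nn as nn

class Scale(nn.Module):
    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    # Override forward, not __call__, so nn.Module.__call__ stays in charge
    def forward(self, x: torch.Tensor):
        return self.factor * x

m = Scale(2.0)
# Forward hooks only fire when the module is invoked as m(x),
# which routes through nn.Module.__call__ -> forward.
m.register_forward_hook(lambda mod, inp, out: print('forward hook fired:', out))
y = m(torch.ones(2))  # prints the hook message; y == tensor([2., 2.])

Call sites are unchanged by the rename, since `m(x)` still goes through `nn.Module.__call__`.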
@@ -211,7 +211,7 @@ class EncoderRNN(Module):
         # Head to get $\hat{\sigma}$
         self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
 
-    def __call__(self, inputs: torch.Tensor, state=None):
+    def forward(self, inputs: torch.Tensor, state=None):
         # The hidden state of the bidirectional LSTM is the concatenation of the
         # output of the last token in the forward direction and
         # first token in the reverse direction, which is what we want.
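Note on the EncoderRNN context above: for a PyTorch bidirectional LSTM, the final hidden state holds one entry per direction, so concatenating them gives the forward direction's last step and the reverse direction's first step. A sketch (illustrative, assumed sizes):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=5, hidden_size=8, bidirectional=True)
x = torch.randn(10, 3, 5)                     # [seq_len, batch, features]
out, (h_n, c_n) = lstm(x)                     # h_n: [2, batch, 8], one per direction
hidden = torch.cat([h_n[0], h_n[1]], dim=-1)  # [batch, 16] = 2 * enc_hidden_size
# The forward half of the last output step and the reverse half of the
# first output step match the two halves of `hidden`.
assert torch.allclose(hidden[:, :8], out[-1, :, :8])
assert torch.allclose(hidden[:, 8:], out[0, :, 8:])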
@@ -269,7 +269,7 @@ class DecoderRNN(Module):
         self.n_distributions = n_distributions
         self.dec_hidden_size = dec_hidden_size
 
-    def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
+    def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
         # Calculate the initial state
         if state is None:
             # $[h_0; c_0] = \tanh(W_{z}z + b_z)$
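Note on the DecoderRNN hunk: the comment $[h_0; c_0] = \tanh(W_{z}z + b_z)$ describes mapping the latent vector to the initial LSTM state. A sketch of that step (assumed shapes and a hypothetical `init_state` layer, for illustration only):

import torch
import torch.nn as nn

d_z, dec_hidden_size = 128, 512
init_state = nn.Linear(d_z, 2 * dec_hidden_size)  # plays the role of W_z, b_z
z = torch.randn(4, d_z)
# Split tanh(W_z z + b_z) into the two LSTM state tensors h_0 and c_0
h0, c0 = torch.split(torch.tanh(init_state(z)), dec_hidden_size, dim=-1)
# Reshape to [1, batch, dec_hidden_size] for a single-layer LSTM
state = (h0.unsqueeze(0).contiguous(), c0.unsqueeze(0).contiguous())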
@@ -314,7 +314,7 @@ class ReconstructionLoss(Module):
     ## Reconstruction Loss
     """
 
-    def __call__(self, mask: torch.Tensor, target: torch.Tensor,
+    def forward(self, mask: torch.Tensor, target: torch.Tensor,
                  dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
         # Get $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$
         pi, mix = dist.get_distribution()
@@ -355,7 +355,7 @@ class KLDivLoss(Module):
     This calculates the KL divergence between a given normal distribution and $\mathcal{N}(0, 1)$
     """
 
-    def __call__(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
+    def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
         # $$L_{KL} = - \frac{1}{2 N_z} \bigg( 1 + \hat{\sigma} - \mu^2 - \exp(\hat{\sigma}) \bigg)$$
         return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
 
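Note on the KLDivLoss hunk: with $\hat{\sigma} = \log \sigma^2$, the returned expression is the standard closed form of $KL(\mathcal{N}(\mu, \sigma^2) \| \mathcal{N}(0, 1))$ averaged over elements. A quick numerical check (not part of the diff):

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(1000)
sigma_hat = torch.randn(1000)        # \hat{\sigma} = \log \sigma^2
sigma = torch.exp(0.5 * sigma_hat)
closed_form = -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
reference = kl_divergence(Normal(mu, sigma), Normal(0.0, 1.0)).mean()
assert torch.allclose(closed_form, reference, atol=1e-4)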
@@ -145,7 +145,7 @@ class DPFP(Module):
         self.relu = nn.ReLU()
         self.eps = eps
 
-    def __call__(self, k: torch.Tensor):
+    def forward(self, k: torch.Tensor):
         # Get $\color{lightgreen}{\phi(k)}$
         k = self.dpfp(k)
         # Normalize by $\sum^{d_{dot}}_{j=1} \color{lightgreen}{\phi(k)_j}$
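Note on the DPFP hunk: the comment says $\phi(k)$ is normalized by its sum, with `eps` guarding against division by zero. A sketch of that normalization step (inferred from the comments above, not necessarily the exact repo code):

import torch

def normalize(k: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Divide each feature vector by the sum of its components
    return k / (torch.sum(k, dim=-1, keepdim=True) + eps)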
@@ -228,7 +228,7 @@ class FastWeightsAttention(Module):
         # Dropout
         self.dropout = nn.Dropout(dropout_prob)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Get the number of steps $L$
         seq_len = x.shape[0]
         # $\color{lightgreen}{\phi'(q^{(i)})}$ for all steps and heads
@@ -291,7 +291,7 @@ class FastWeightsAttentionTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Calculate fast weights self attention
         attn = self.attn(x)
         # Add the self attention results
@@ -319,7 +319,7 @@ class FastWeightsAttentionTransformer(Module):
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         for i, layer in enumerate(self.layers):
             # Get layer output
             x = layer(x)
@@ -43,7 +43,7 @@ class FastWeightsAttention(Module):
         # Dropout
         self.dropout = nn.Dropout(dropout_prob)
 
-    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         query = self.phi(self.query(x))
         key = self.phi(self.key(x))
         value = self.value(x)
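Note on the token-wise FastWeightsAttention hunk: in the fast weight programmers formulation, each step retrieves the value currently stored against $\phi(k)$, writes an interpolated update as an outer product, and reads out with $\phi(q)$. A simplified single-head, single-step sketch (hedged; not the exact repo code):

import torch

def fast_weight_step(W: torch.Tensor, q: torch.Tensor, k: torch.Tensor,
                     v: torch.Tensor, beta: torch.Tensor):
    # W: [d_phi, d_v] fast weight memory; q, k: [d_phi]; v: [d_v]
    # Retrieve the value currently paired with k
    v_old = torch.einsum('do,d->o', W, k)
    # Write: interpolate toward the new value via an outer product update
    W = W + beta * torch.einsum('d,o->do', k, v - v_old)
    # Read out with the query
    return W, torch.einsum('do,d->o', W, q)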
@@ -84,7 +84,7 @@ class FastWeightsAttentionTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])
 
-    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         attn, weights = self.attn(x, weights)
         # Add the self attention results
         x = x + self.dropout(attn)
@@ -108,7 +108,7 @@ class FastWeightsAttentionTransformer(Module):
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])
 
-    def __call__(self, x_seq: torch.Tensor):
+    def forward(self, x_seq: torch.Tensor):
         # Split the input to a list along the sequence axis
         x_seq = torch.unbind(x_seq, dim=0)
         # List to store the outputs
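Note on the token-wise transformer hunk: `torch.unbind` splits the sequence axis into a tuple of per-step tensors so the fast weights can be threaded from one step to the next. A sketch (illustrative):

import torch

x_seq = torch.randn(10, 3, 16)        # [seq_len, batch_size, d_model]
steps = torch.unbind(x_seq, dim=0)    # tuple of 10 tensors, each [3, 16]
weights = None                        # fast weights carried across steps
res = []
for step in steps:
    # a real layer would read and update `weights` here; identity stand-in
    res.append(step)
out = torch.stack(res)                # back to [10, 3, 16]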
@@ -75,7 +75,7 @@ class PatchEmbeddings(Module):
         # transformation on each patch.
         self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         """
         * `x` is the input image of shape `[batch_size, channels, height, width]`
         """
@@ -109,7 +109,7 @@ class LearnedPositionalEmbeddings(Module):
         # Positional embeddings for each location
         self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         """
         * `x` is the patch embeddings of shape `[patches, batch_size, d_model]`
         """
@@ -141,7 +141,7 @@ class ClassificationHead(Module):
         # Second layer
         self.linear2 = nn.Linear(n_hidden, n_classes)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         """
         * `x` is the transformer encoding for `[CLS]` token
         """
@@ -187,7 +187,7 @@ class VisionTransformer(Module):
         # Final normalization layer
         self.ln = nn.LayerNorm([transformer_layer.size])
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         """
         * `x` is the input image of shape `[batch_size, channels, height, width]`
         """
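Note on the PatchEmbeddings/VisionTransformer hunks: a `Conv2d` whose kernel size and stride both equal `patch_size` applies one linear transformation per non-overlapping patch. A shape walk-through (illustrative, assumed sizes):

import torch
import torch.nn as nn

in_channels, d_model, patch_size = 3, 64, 16
conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
x = torch.randn(2, 3, 224, 224)    # [batch_size, channels, height, width]
p = conv(x)                        # [2, 64, 14, 14]: one vector per 16x16 patch
# Rearrange to [patches, batch_size, d_model] for the transformer
p = p.flatten(2).permute(2, 0, 1)  # [196, 2, 64]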