diff --git a/labml_nn/sketch_rnn/__init__.py b/labml_nn/sketch_rnn/__init__.py
index c3caf948..275ea9da 100644
--- a/labml_nn/sketch_rnn/__init__.py
+++ b/labml_nn/sketch_rnn/__init__.py
@@ -211,7 +211,7 @@ class EncoderRNN(Module):
         # Head to get $\hat{\sigma}$
         self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
 
-    def __call__(self, inputs: torch.Tensor, state=None):
+    def forward(self, inputs: torch.Tensor, state=None):
         # The hidden state of the bidirectional LSTM is the concatenation of the
         # output of the last token in the forward direction and
         # first token in the reverse direction, which is what we want.
@@ -269,7 +269,7 @@ class DecoderRNN(Module):
         self.n_distributions = n_distributions
         self.dec_hidden_size = dec_hidden_size
 
-    def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
+    def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
         # Calculate the initial state
         if state is None:
             # $[h_0; c_0] = \tanh(W_{z}z + b_z)$
@@ -314,7 +314,7 @@ class ReconstructionLoss(Module):
     ## Reconstruction Loss
     """
 
-    def __call__(self, mask: torch.Tensor, target: torch.Tensor,
+    def forward(self, mask: torch.Tensor, target: torch.Tensor,
                  dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
         # Get $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$
         pi, mix = dist.get_distribution()
@@ -355,7 +355,7 @@ class KLDivLoss(Module):
     This calculates the KL divergence between a given normal distribution and $\mathcal{N}(0, 1)$
     """
 
-    def __call__(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
+    def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
         # $$L_{KL} = - \frac{1}{2 N_z} \bigg( 1 + \hat{\sigma} - \mu^2 - \exp(\hat{\sigma}) \bigg)$$
         return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
 
diff --git a/labml_nn/transformers/fast_weights/__init__.py b/labml_nn/transformers/fast_weights/__init__.py
index 279cd42c..52d4bf3d 100644
--- a/labml_nn/transformers/fast_weights/__init__.py
+++ b/labml_nn/transformers/fast_weights/__init__.py
@@ -145,7 +145,7 @@ class DPFP(Module):
         self.relu = nn.ReLU()
         self.eps = eps
 
-    def __call__(self, k: torch.Tensor):
+    def forward(self, k: torch.Tensor):
         # Get $\color{lightgreen}{\phi(k)}$
         k = self.dpfp(k)
         # Normalize by $\sum^{d_{dot}}_{j=1} \color{lightgreen}{\phi(k)_j}$
@@ -228,7 +228,7 @@ class FastWeightsAttention(Module):
         # Dropout
         self.dropout = nn.Dropout(dropout_prob)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Get the number of steps $L$
         seq_len = x.shape[0]
         # $\color{lightgreen}{\phi'(q^{(i)})}$ for all steps and heads
@@ -291,7 +291,7 @@ class FastWeightsAttentionTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Calculate fast weights self attention
         attn = self.attn(x)
         # Add the self attention results
@@ -319,7 +319,7 @@ class FastWeightsAttentionTransformer(Module):
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         for i, layer in enumerate(self.layers):
             # Get layer output
             x = layer(x)
diff --git a/labml_nn/transformers/fast_weights/token_wise.py b/labml_nn/transformers/fast_weights/token_wise.py
index a73e734d..77df47a3 100644
--- a/labml_nn/transformers/fast_weights/token_wise.py
+++ b/labml_nn/transformers/fast_weights/token_wise.py
@@ -43,7 +43,7 @@ class FastWeightsAttention(Module):
         # Dropout
         self.dropout = nn.Dropout(dropout_prob)
 
-    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         query = self.phi(self.query(x))
         key = self.phi(self.key(x))
         value = self.value(x)
@@ -84,7 +84,7 @@ class FastWeightsAttentionTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])
 
-    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         attn, weights = self.attn(x, weights)
         # Add the self attention results
         x = x + self.dropout(attn)
@@ -108,7 +108,7 @@ class FastWeightsAttentionTransformer(Module):
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])
 
-    def __call__(self, x_seq: torch.Tensor):
+    def forward(self, x_seq: torch.Tensor):
         # Split the input to a list along the sequence axis
         x_seq = torch.unbind(x_seq, dim=0)
         # List to store the outputs
diff --git a/labml_nn/transformers/vit/__init__.py b/labml_nn/transformers/vit/__init__.py
index 3fd991c4..8c823adc 100644
--- a/labml_nn/transformers/vit/__init__.py
+++ b/labml_nn/transformers/vit/__init__.py
@@ -75,7 +75,7 @@ class PatchEmbeddings(Module):
         # transformation on each patch.
         self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         """
         * `x` is the input image of shape `[batch_size, channels, height, width]`
         """
@@ -109,7 +109,7 @@ class LearnedPositionalEmbeddings(Module):
         # Positional embeddings for each location
         self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         """
         * `x` is the patch embeddings of shape `[patches, batch_size, d_model]`
         """
@@ -141,7 +141,7 @@ class ClassificationHead(Module):
         # Second layer
         self.linear2 = nn.Linear(n_hidden, n_classes)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         """
         * `x` is the transformer encoding for `[CLS]` token
         """
@@ -187,7 +187,7 @@ class VisionTransformer(Module):
         # Final normalization layer
         self.ln = nn.LayerNorm([transformer_layer.size])
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         """
         * `x` is the input image of shape `[batch_size, channels, height, width]`
         """
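
Reviewer note: the rename from `__call__` to `forward` matters because PyTorch's `nn.Module.__call__` runs registered hooks and then dispatches to `forward`, so overriding `forward` (rather than `__call__`) keeps that machinery intact. Below is a minimal sketch of that behaviour, assuming only stock PyTorch; the `Scale` module is hypothetical and is not part of this diff or the repository.

import torch
from torch import nn


class Scale(nn.Module):
    """Hypothetical module used only to illustrate why `forward` is overridden."""

    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor):
        # The computation lives in `forward`; callers still invoke `module(x)`.
        return x * self.factor


module = Scale(2.0)
# This hook fires because `module(x)` goes through `nn.Module.__call__`,
# which calls `forward` and then the registered forward hooks.
module.register_forward_hook(lambda m, inp, out: print('hook saw', tuple(out.shape)))
y = module(torch.ones(3))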