mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-31 02:39:16 +08:00
	__call__ -> forward
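This commit moves each module's computation from `__call__` to `forward`. In PyTorch, `nn.Module.__call__` is the wrapper that invokes `forward` and runs any registered hooks around it, so defining the computation in `forward` (and still calling the module as `module(x)`) keeps that machinery intact. A minimal sketch of the idiom; the `Squared` module and the hook below are illustrative, not part of the repo:

import torch
from torch import nn

class Squared(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The computation goes in `forward`; nn.Module.__call__ wraps it.
        return x * x

m = Squared()
# Hooks only fire when the module is invoked as m(x), i.e. through __call__
m.register_forward_hook(lambda mod, inp, out: print('hook saw output of shape', tuple(out.shape)))
y = m(torch.ones(2, 3))   # prints: hook saw output of shape (2, 3)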
@@ -211,7 +211,7 @@ class EncoderRNN(Module):
         # Head to get $\hat{\sigma}$
         self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
 
-    def __call__(self, inputs: torch.Tensor, state=None):
+    def forward(self, inputs: torch.Tensor, state=None):
         # The hidden state of the bidirectional LSTM is the concatenation of the
         # output of the last token in the forward direction and
         # first token in the reverse direction, which is what we want.
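The comment in this hunk describes why concatenating the bidirectional LSTM's final hidden states gives a summary of the whole sketch. A minimal sketch of that behavior, with illustrative sizes:

import torch
from torch import nn

# The final hidden state of a bidirectional LSTM has one row per direction.
lstm = nn.LSTM(input_size=5, hidden_size=7, bidirectional=True)
x = torch.randn(12, 3, 5)                  # [seq_len, batch, input_size]
_, (h, _) = lstm(x)                        # h: [2, batch, 7]
# h[0] is the forward direction after the last token;
# h[1] is the reverse direction after the first token.
summary = torch.cat([h[0], h[1]], dim=-1)  # [batch, 14], input to the mu/sigma heads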
@@ -269,7 +269,7 @@ class DecoderRNN(Module):
         self.n_distributions = n_distributions
         self.dec_hidden_size = dec_hidden_size
 
-    def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
+    def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
         # Calculate the initial state
         if state is None:
             # $[h_0; c_0] = \tanh(W_{z}z + b_z)$
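The $[h_0; c_0] = \tanh(W_{z}z + b_z)$ step above maps the latent vector to the decoder LSTM's initial state. A sketch under illustrative sizes; the `init_state` layer name here is hypothetical:

import torch
from torch import nn

d_z, dec_hidden_size, batch = 16, 32, 4
init_state = nn.Linear(d_z, 2 * dec_hidden_size)   # W_z and b_z in one layer
z = torch.randn(batch, d_z)
h0, c0 = torch.split(torch.tanh(init_state(z)), dec_hidden_size, dim=-1)
# The LSTM expects [num_layers, batch, hidden], hence the unsqueeze
state = (h0.unsqueeze(0).contiguous(), c0.unsqueeze(0).contiguous())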
@@ -314,7 +314,7 @@ class ReconstructionLoss(Module):
     ## Reconstruction Loss
     """
 
-    def __call__(self, mask: torch.Tensor, target: torch.Tensor,
+    def forward(self, mask: torch.Tensor, target: torch.Tensor,
                  dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
         # Get $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$
         pi, mix = dist.get_distribution()
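`dist.get_distribution()` above yields the mixture weights $\Pi$ and the bivariate Gaussian components. A sketch of the log-likelihood such a mixture assigns to points, using `torch.distributions` with a placeholder identity covariance instead of the $(\sigma_x, \sigma_y, \rho_{xy})$ parameterization; all sizes are illustrative:

import torch
import torch.distributions as D

n_distributions, batch = 20, 8
logits = torch.randn(batch, n_distributions)               # mixture logits for Pi
mean = torch.randn(batch, n_distributions, 2)              # (mu_x, mu_y) per component
cov = torch.eye(2).repeat(batch, n_distributions, 1, 1)    # placeholder covariance
mix = D.MixtureSameFamily(D.Categorical(logits=logits),
                          D.MultivariateNormal(mean, covariance_matrix=cov))
log_prob = mix.log_prob(torch.randn(batch, 2))             # [batch]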
@@ -355,7 +355,7 @@ class KLDivLoss(Module):
     This calculates the KL divergence between a given normal distribution and $\mathcal{N}(0, 1)$
     """
 
-    def __call__(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
+    def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
         # $$L_{KL} = - \frac{1}{2 N_z} \bigg( 1 + \hat{\sigma} - \mu^2 - \exp(\hat{\sigma}) \bigg)$$
         return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
 
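The closed form above is the KL divergence between $\mathcal{N}(\mu, \sigma^2)$, with $\hat{\sigma} = \log \sigma^2$, and the standard normal, averaged over the latent dimensions. A quick check against `torch.distributions` with illustrative sizes:

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(5)
sigma_hat = torch.randn(5)               # \hat{\sigma} = log sigma^2
sigma = torch.exp(sigma_hat / 2)         # standard deviation

closed_form = -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
reference = kl_divergence(Normal(mu, sigma), Normal(0.0, 1.0)).mean()
assert torch.allclose(closed_form, reference)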
@@ -145,7 +145,7 @@ class DPFP(Module):
         self.relu = nn.ReLU()
         self.eps = eps
 
-    def __call__(self, k: torch.Tensor):
+    def forward(self, k: torch.Tensor):
         # Get $\color{lightgreen}{\phi(k)}$
         k = self.dpfp(k)
         # Normalize by $\sum^{d_{dot}}_{j=1} \color{lightgreen}{\phi(k)_j}$
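For context, the `dpfp` call above computes the deterministic parameter-free projection from the paper: concatenate $\mathrm{relu}(k)$ and $\mathrm{relu}(-k)$, then take element-wise products with $\nu$ rolled copies. A self-contained sketch of the feature map plus the normalization in this hunk; this is a simplified reading, not the repo's exact code:

import torch
import torch.nn.functional as F

def dpfp(k: torch.Tensor, nu: int = 1, eps: float = 1e-6) -> torch.Tensor:
    # r(k) = relu([k; -k]) doubles the dimension; values are non-negative
    x = F.relu(torch.cat([k, -k], dim=-1))
    # Element-wise products with nu rolled copies give 2 * d * nu features
    rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, nu + 1)], dim=-1)
    phi = torch.cat([x] * nu, dim=-1) * rolled
    # Normalize by the sum over features, as in the forward method above
    return phi / (phi.sum(dim=-1, keepdim=True) + eps)

phi_k = dpfp(torch.randn(4, 16), nu=2)   # [4, 64]; each row sums to ~1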
@@ -228,7 +228,7 @@ class FastWeightsAttention(Module):
         # Dropout
         self.dropout = nn.Dropout(dropout_prob)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Get the number of steps $L$
         seq_len = x.shape[0]
         # $\color{lightgreen}{\phi'(q^{(i)})}$ for all steps and heads
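This forward method steps through the sequence updating a fast weight memory with the paper's delta rule, $W^{(i)} = W^{(i-1)} + \beta^{(i)} \big(v^{(i)} - W^{(i-1)} \phi'(k^{(i)})\big) \otimes \phi'(k^{(i)})$, then reads $y^{(i)} = W^{(i)} \phi'(q^{(i)})$. A single-head sketch with illustrative sizes, replacing the feature maps with simple normalized positives:

import torch

d_k, d_v, seq_len = 8, 8, 10
W = torch.zeros(d_v, d_k)                        # fast weight memory
phi_q = torch.rand(seq_len, d_k); phi_q = phi_q / phi_q.sum(-1, keepdim=True)
phi_k = torch.rand(seq_len, d_k); phi_k = phi_k / phi_k.sum(-1, keepdim=True)
v = torch.randn(seq_len, d_v)
beta = torch.rand(seq_len)                       # write strength (a sigmoid output in the model)

ys = []
for i in range(seq_len):
    v_bar = W @ phi_k[i]                                   # value currently stored under this key
    W = W + beta[i] * torch.outer(v[i] - v_bar, phi_k[i])  # delta-rule update
    ys.append(W @ phi_q[i])                                # read with the query
y = torch.stack(ys)                              # [seq_len, d_v]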
@@ -291,7 +291,7 @@ class FastWeightsAttentionTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Calculate fast weights self attention
         attn = self.attn(x)
         # Add the self attention results
@@ -319,7 +319,7 @@ class FastWeightsAttentionTransformer(Module):
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         for i, layer in enumerate(self.layers):
             # Get layer output
             x = layer(x)
@@ -43,7 +43,7 @@ class FastWeightsAttention(Module):
         # Dropout
         self.dropout = nn.Dropout(dropout_prob)
 
-    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         query = self.phi(self.query(x))
         key = self.phi(self.key(x))
         value = self.value(x)
@@ -84,7 +84,7 @@ class FastWeightsAttentionTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])
 
-    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         attn, weights = self.attn(x, weights)
         # Add the self attention results
         x = x + self.dropout(attn)
@@ -108,7 +108,7 @@ class FastWeightsAttentionTransformer(Module):
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])
 
-    def __call__(self, x_seq: torch.Tensor):
+    def forward(self, x_seq: torch.Tensor):
         # Split the input to a list along the sequence axis
         x_seq = torch.unbind(x_seq, dim=0)
         # List to store the outputs
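In this token-by-token variant the fast weights are threaded through the call explicitly: the sequence is unbound along `dim=0` and each step receives the previous step's weights. A sketch of that driving loop with a trivial stand-in for the layer; the real layer is the `FastWeightsAttentionTransformerLayer` above:

import torch
from typing import Optional, Tuple

def layer(x: torch.Tensor, weights: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical stand-in: a real layer applies fast weights attention
    # and returns its updated fast weight memory alongside the output.
    weights = torch.zeros(x.shape[-1], x.shape[-1]) if weights is None else weights
    return x, weights

x_seq = torch.randn(10, 2, 16)            # [seq_len, batch, d_model]
weights = None
out = []
for x in torch.unbind(x_seq, dim=0):      # one tensor per step, each [batch, d_model]
    x, weights = layer(x, weights)
    out.append(x)
y = torch.stack(out)                      # back to [seq_len, batch, d_model]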
@@ -75,7 +75,7 @@ class PatchEmbeddings(Module):
         # transformation on each patch.
         self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         """
         * `x` is the input image of shape `[batch_size, channels, height, width]`
         """
@@ -109,7 +109,7 @@ class LearnedPositionalEmbeddings(Module):
         # Positional embeddings for each location
         self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         """
         * `x` is the patch embeddings of shape `[patches, batch_size, d_model]`
         """
@@ -141,7 +141,7 @@ class ClassificationHead(Module):
         # Second layer
         self.linear2 = nn.Linear(n_hidden, n_classes)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         """
         * `x` is the transformer encoding for `[CLS]` token
         """
@@ -187,7 +187,7 @@ class VisionTransformer(Module):
         # Final normalization layer
         self.ln = nn.LayerNorm([transformer_layer.size])
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         """
         * `x` is the input image of shape `[batch_size, channels, height, width]`
         """
Author: Varuna Jayasiri