mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 02:39:16 +08:00 
			
		
		
		
	__call__ -> forward
This commit is contained in:
		| @ -85,7 +85,7 @@ class GraphAttentionLayer(Module): | |||||||
|         # Dropout layer to be applied for attention |         # Dropout layer to be applied for attention | ||||||
|         self.dropout = nn.Dropout(dropout) |         self.dropout = nn.Dropout(dropout) | ||||||
|  |  | ||||||
|     def __call__(self, h: torch.Tensor, adj_mat: torch.Tensor): |     def forward(self, h: torch.Tensor, adj_mat: torch.Tensor): | ||||||
|         """ |         """ | ||||||
|         * `h`, $\mathbf{h}$ is the input node embeddings of shape `[n_nodes, in_features]`. |         * `h`, $\mathbf{h}$ is the input node embeddings of shape `[n_nodes, in_features]`. | ||||||
|         * `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`. |         * `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`. | ||||||
|  | |||||||
| @ -134,7 +134,7 @@ class GAT(Module): | |||||||
|         # Dropout |         # Dropout | ||||||
|         self.dropout = nn.Dropout(dropout) |         self.dropout = nn.Dropout(dropout) | ||||||
|  |  | ||||||
|     def __call__(self, x: torch.Tensor, adj_mat: torch.Tensor): |     def forward(self, x: torch.Tensor, adj_mat: torch.Tensor): | ||||||
|         """ |         """ | ||||||
|         * `x` is the features vectors of shape `[n_nodes, in_features]` |         * `x` is the features vectors of shape `[n_nodes, in_features]` | ||||||
|         * `adj_mat` is the adjacency matrix of the form |         * `adj_mat` is the adjacency matrix of the form | ||||||
|  | |||||||
| @ -121,7 +121,7 @@ class GraphAttentionV2Layer(Module): | |||||||
|         # Dropout layer to be applied for attention |         # Dropout layer to be applied for attention | ||||||
|         self.dropout = nn.Dropout(dropout) |         self.dropout = nn.Dropout(dropout) | ||||||
|  |  | ||||||
|     def __call__(self, h: torch.Tensor, adj_mat: torch.Tensor): |     def forward(self, h: torch.Tensor, adj_mat: torch.Tensor): | ||||||
|         """ |         """ | ||||||
|         * `h`, $\mathbf{h}$ is the input node embeddings of shape `[n_nodes, in_features]`. |         * `h`, $\mathbf{h}$ is the input node embeddings of shape `[n_nodes, in_features]`. | ||||||
|         * `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`. |         * `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`. | ||||||
|  | |||||||
| @ -50,7 +50,7 @@ class GATv2(Module): | |||||||
|         # Dropout |         # Dropout | ||||||
|         self.dropout = nn.Dropout(dropout) |         self.dropout = nn.Dropout(dropout) | ||||||
|  |  | ||||||
|     def __call__(self, x: torch.Tensor, adj_mat: torch.Tensor): |     def forward(self, x: torch.Tensor, adj_mat: torch.Tensor): | ||||||
|         """ |         """ | ||||||
|         * `x` is the features vectors of shape `[n_nodes, in_features]` |         * `x` is the features vectors of shape `[n_nodes, in_features]` | ||||||
|         * `adj_mat` is the adjacency matrix of the form |         * `adj_mat` is the adjacency matrix of the form | ||||||
|  | |||||||
| @ -22,7 +22,7 @@ class AutoregressiveModel(Module): | |||||||
|         self.lstm = rnn_model |         self.lstm = rnn_model | ||||||
|         self.generator = nn.Linear(d_model, n_vocab) |         self.generator = nn.Linear(d_model, n_vocab) | ||||||
|  |  | ||||||
|     def __call__(self, x: torch.Tensor): |     def forward(self, x: torch.Tensor): | ||||||
|         x = self.src_embed(x) |         x = self.src_embed(x) | ||||||
|         # Embed the tokens (`src`) and run it through the the transformer |         # Embed the tokens (`src`) and run it through the the transformer | ||||||
|         res, state = self.lstm(x) |         res, state = self.lstm(x) | ||||||
|  | |||||||
| @ -147,7 +147,7 @@ class HyperLSTMCell(Module): | |||||||
|         self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)]) |         self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)]) | ||||||
|         self.layer_norm_c = nn.LayerNorm(hidden_size) |         self.layer_norm_c = nn.LayerNorm(hidden_size) | ||||||
|  |  | ||||||
|     def __call__(self, x: torch.Tensor, |     def forward(self, x: torch.Tensor, | ||||||
|                 h: torch.Tensor, c: torch.Tensor, |                 h: torch.Tensor, c: torch.Tensor, | ||||||
|                 h_hat: torch.Tensor, c_hat: torch.Tensor): |                 h_hat: torch.Tensor, c_hat: torch.Tensor): | ||||||
|         # $$ |         # $$ | ||||||
| @ -202,6 +202,7 @@ class HyperLSTM(Module): | |||||||
|     """ |     """ | ||||||
|     # HyperLSTM module |     # HyperLSTM module | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int): |     def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int): | ||||||
|         """ |         """ | ||||||
|         Create a network of `n_layers` of HyperLSTM. |         Create a network of `n_layers` of HyperLSTM. | ||||||
| @ -220,7 +221,7 @@ class HyperLSTM(Module): | |||||||
|                                    [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in |                                    [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in | ||||||
|                                     range(n_layers - 1)]) |                                     range(n_layers - 1)]) | ||||||
|  |  | ||||||
|     def __call__(self, x: torch.Tensor, |     def forward(self, x: torch.Tensor, | ||||||
|                 state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None): |                 state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None): | ||||||
|         """ |         """ | ||||||
|         * `x` has shape `[n_steps, batch_size, input_size]` and |         * `x` has shape `[n_steps, batch_size, input_size]` and | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri