__call__ -> forward

Varuna Jayasiri
2021-08-19 15:48:49 +05:30
parent 02309fa4cc
commit 6ff41d58b9
6 changed files with 11 additions and 10 deletions
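
In PyTorch, `nn.Module.__call__` wraps `forward`, running any registered hooks around it, so the conventional override point for a module is `forward`; defining `__call__` directly bypasses that machinery. A minimal sketch of the convention the renamed methods follow (the `Square` module is illustrative only, not part of this commit):

import torch
from torch import nn

class Square(nn.Module):
    # Override `forward`, not `__call__`.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * x

m = Square()
y = m(torch.tensor(3.0))  # m(...) goes through nn.Module.__call__, which calls forward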

@@ -85,7 +85,7 @@ class GraphAttentionLayer(Module):
         # Dropout layer to be applied for attention
         self.dropout = nn.Dropout(dropout)
 
-    def __call__(self, h: torch.Tensor, adj_mat: torch.Tensor):
+    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
         """
         * `h`, $\mathbf{h}$ is the input node embeddings of shape `[n_nodes, in_features]`.
         * `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`.

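A hedged usage sketch for the layer above after the rename. Only the `forward(h, adj_mat)` signature and the input shapes come from the docstring; the constructor arguments (`in_features`, `out_features`, `n_heads`, `dropout`) and the assumption that the class is in scope are illustrative:

import torch
# assuming GraphAttentionLayer from the file above is in scope
n_nodes, in_features, n_heads = 34, 128, 8
# hypothetical constructor arguments
layer = GraphAttentionLayer(in_features=in_features, out_features=16, n_heads=n_heads, dropout=0.6)

h = torch.randn(n_nodes, in_features)                              # [n_nodes, in_features]
adj_mat = torch.ones(n_nodes, n_nodes, n_heads, dtype=torch.bool)  # [n_nodes, n_nodes, n_heads]
# Calling the layer dispatches through nn.Module.__call__ to the renamed forward
out = layer(h, adj_mat)
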
@@ -134,7 +134,7 @@ class GAT(Module):
         # Dropout
         self.dropout = nn.Dropout(dropout)
 
-    def __call__(self, x: torch.Tensor, adj_mat: torch.Tensor):
+    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
         """
         * `x` is the features vectors of shape `[n_nodes, in_features]`
         * `adj_mat` is the adjacency matrix of the form

@@ -121,7 +121,7 @@ class GraphAttentionV2Layer(Module):
         # Dropout layer to be applied for attention
         self.dropout = nn.Dropout(dropout)
 
-    def __call__(self, h: torch.Tensor, adj_mat: torch.Tensor):
+    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
         """
         * `h`, $\mathbf{h}$ is the input node embeddings of shape `[n_nodes, in_features]`.
         * `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`.

@@ -50,7 +50,7 @@ class GATv2(Module):
         # Dropout
         self.dropout = nn.Dropout(dropout)
 
-    def __call__(self, x: torch.Tensor, adj_mat: torch.Tensor):
+    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
         """
         * `x` is the features vectors of shape `[n_nodes, in_features]`
         * `adj_mat` is the adjacency matrix of the form

@@ -22,7 +22,7 @@ class AutoregressiveModel(Module):
         self.lstm = rnn_model
         self.generator = nn.Linear(d_model, n_vocab)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         x = self.src_embed(x)
         # Embed the tokens (`src`) and run it through the the transformer
         res, state = self.lstm(x)

@@ -147,7 +147,7 @@ class HyperLSTMCell(Module):
         self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
         self.layer_norm_c = nn.LayerNorm(hidden_size)
 
-    def __call__(self, x: torch.Tensor,
+    def forward(self, x: torch.Tensor,
                 h: torch.Tensor, c: torch.Tensor,
                 h_hat: torch.Tensor, c_hat: torch.Tensor):
         # $$
@@ -202,6 +202,7 @@ class HyperLSTM(Module):
     """
     # HyperLSTM module
     """
     def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):
         """
         Create a network of `n_layers` of HyperLSTM.
@@ -220,7 +221,7 @@
             [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
              range(n_layers - 1)])
 
-    def __call__(self, x: torch.Tensor,
+    def forward(self, x: torch.Tensor,
                 state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
         """
         * `x` has shape `[n_steps, batch_size, input_size]` and
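
A hedged usage sketch for `HyperLSTM` after the rename. The constructor signature and the shape of `x` come from the hunks above; the return value (outputs plus the final recurrent state) is an assumption not shown in this diff:

import torch
# assuming HyperLSTM from the file above is in scope
model = HyperLSTM(input_size=4, hidden_size=8, hyper_size=6, n_z=4, n_layers=2)

x = torch.randn(10, 2, 4)  # [n_steps, batch_size, input_size]
# state defaults to None; calling the module goes through nn.Module.__call__ to forward
out, state = model(x)      # assumed to return outputs and the final recurrent state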