__call__ -> forward

This commit is contained in:
Varuna Jayasiri
2021-08-19 15:48:49 +05:30
parent 02309fa4cc
commit 6ff41d58b9
6 changed files with 11 additions and 10 deletions

View File

@ -85,7 +85,7 @@ class GraphAttentionLayer(Module):
# Dropout layer to be applied for attention # Dropout layer to be applied for attention
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
def __call__(self, h: torch.Tensor, adj_mat: torch.Tensor): def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
""" """
* `h`, $\mathbf{h}$ is the input node embeddings of shape `[n_nodes, in_features]`. * `h`, $\mathbf{h}$ is the input node embeddings of shape `[n_nodes, in_features]`.
* `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`. * `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`.

View File

@ -134,7 +134,7 @@ class GAT(Module):
# Dropout # Dropout
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
def __call__(self, x: torch.Tensor, adj_mat: torch.Tensor): def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
""" """
* `x` is the features vectors of shape `[n_nodes, in_features]` * `x` is the features vectors of shape `[n_nodes, in_features]`
* `adj_mat` is the adjacency matrix of the form * `adj_mat` is the adjacency matrix of the form

View File

@ -121,7 +121,7 @@ class GraphAttentionV2Layer(Module):
# Dropout layer to be applied for attention # Dropout layer to be applied for attention
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
def __call__(self, h: torch.Tensor, adj_mat: torch.Tensor): def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
""" """
* `h`, $\mathbf{h}$ is the input node embeddings of shape `[n_nodes, in_features]`. * `h`, $\mathbf{h}$ is the input node embeddings of shape `[n_nodes, in_features]`.
* `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`. * `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`.

View File

@ -50,7 +50,7 @@ class GATv2(Module):
# Dropout # Dropout
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
def __call__(self, x: torch.Tensor, adj_mat: torch.Tensor): def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
""" """
* `x` is the features vectors of shape `[n_nodes, in_features]` * `x` is the features vectors of shape `[n_nodes, in_features]`
* `adj_mat` is the adjacency matrix of the form * `adj_mat` is the adjacency matrix of the form

View File

@ -22,7 +22,7 @@ class AutoregressiveModel(Module):
self.lstm = rnn_model self.lstm = rnn_model
self.generator = nn.Linear(d_model, n_vocab) self.generator = nn.Linear(d_model, n_vocab)
def __call__(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
x = self.src_embed(x) x = self.src_embed(x)
# Embed the tokens (`src`) and run it through the the transformer # Embed the tokens (`src`) and run it through the the transformer
res, state = self.lstm(x) res, state = self.lstm(x)

View File

@ -147,9 +147,9 @@ class HyperLSTMCell(Module):
self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)]) self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
self.layer_norm_c = nn.LayerNorm(hidden_size) self.layer_norm_c = nn.LayerNorm(hidden_size)
def __call__(self, x: torch.Tensor, def forward(self, x: torch.Tensor,
h: torch.Tensor, c: torch.Tensor, h: torch.Tensor, c: torch.Tensor,
h_hat: torch.Tensor, c_hat: torch.Tensor): h_hat: torch.Tensor, c_hat: torch.Tensor):
# $$ # $$
# \hat{x}_t = \begin{pmatrix} # \hat{x}_t = \begin{pmatrix}
# h_{t-1} \\ # h_{t-1} \\
@ -202,6 +202,7 @@ class HyperLSTM(Module):
""" """
# HyperLSTM module # HyperLSTM module
""" """
def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int): def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):
""" """
Create a network of `n_layers` of HyperLSTM. Create a network of `n_layers` of HyperLSTM.
@ -220,8 +221,8 @@ class HyperLSTM(Module):
[HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
range(n_layers - 1)]) range(n_layers - 1)])
def __call__(self, x: torch.Tensor, def forward(self, x: torch.Tensor,
state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None): state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
""" """
* `x` has shape `[n_steps, batch_size, input_size]` and * `x` has shape `[n_steps, batch_size, input_size]` and
* `state` is a tuple of $h, c, \hat{h}, \hat{c}$. * `state` is a tuple of $h, c, \hat{h}, \hat{c}$.