Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-11-01 03:43:09 +08:00)
use forward
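This commit renames the call method of each `Module` subclass from `__call__` to `forward` (and, in one experiment file, switches the model's base class from `nn.Module` to `labml_helpers.module.Module`). As a reminder of why `forward` is the idiomatic override for PyTorch modules: calling a module goes through `nn.Module.__call__`, which runs registered hooks and then dispatches to `forward`. The toy class below is a minimal sketch of that pattern and is not part of the repository.

```python
import torch
from torch import nn


class Scale(nn.Module):
    """Toy module: subclasses override `forward`, not `__call__`."""

    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `Scale(2.0)(x)` goes through `nn.Module.__call__`, which runs any
        # registered hooks and then dispatches here.
        return self.factor * x


print(Scale(2.0)(torch.ones(3)))  # tensor([2., 2., 2.])
```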
@@ -78,7 +78,7 @@ class FeedForward(Module):
         # be multiplied by the gate, parameterized by weight $V$ and bias $c$
         self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)

-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # $f(x W_1 + b_1)$
         g = self.activation(self.layer1(x))
         # If gated, $f(x W_1 + b_1) \otimes (x V + b)$
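For context on the hunk above, here is a minimal sketch of the gated feed-forward pass its comments describe, $f(x W_1 + b_1) \otimes (x V + c)$, reusing the `layer1` / `linear_v` / `activation` names from the hunk. The output projection (`layer2` here) and the choice of activation are assumptions, not taken from the diff.

```python
import torch
from torch import nn


class GatedFeedForward(nn.Module):
    # Sketch of the gated position-wise FFN described in the hunk above.
    def __init__(self, d_model: int = 512, d_ff: int = 2048, bias_gate: bool = True):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)
        self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)
        self.layer2 = nn.Linear(d_ff, d_model)  # assumed output projection
        self.activation = nn.GELU()             # assumed activation $f$

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # $f(x W_1 + b_1)$
        g = self.activation(self.layer1(x))
        # Gate with $(x V + c)$ and project back to `d_model`
        return self.layer2(g * self.linear_v(x))
```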
@@ -155,7 +155,7 @@ class FeedbackAttention(Module):
         # $A_j$
         return ac + bd

-    def __call__(self, *,
+    def forward(self, *,
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor):
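The `ac + bd` returned above is an attention logit built from a content term plus a relative-position term. The snippet below is only a generic sketch of that decomposition, with assumed tensor shapes (`[seq, batch, heads, d_k]`) and a hypothetical `key_pos` embedding; FeedbackAttention's actual shapes and its relative-position handling may differ.

```python
import torch

# Assumed shapes, for illustration only.
seq_q, seq_k, batch, heads, d_k = 3, 5, 2, 4, 8
query = torch.randn(seq_q, batch, heads, d_k)
key = torch.randn(seq_k, batch, heads, d_k)
key_pos = torch.randn(seq_k, heads, d_k)  # hypothetical relative-position embeddings

ac = torch.einsum('ibhd,jbhd->ijbh', query, key)     # content-based term
bd = torch.einsum('ibhd,jhd->ijbh', query, key_pos)  # position-based term
scores = ac + bd
print(scores.shape)  # torch.Size([3, 5, 2, 4])
```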
@@ -226,7 +226,7 @@ class FeedbackTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])

-    def __call__(self, *,
+    def forward(self, *,
                 x: torch.Tensor,
                 key: Optional[torch.Tensor],
                 value: Optional[torch.Tensor]):
@@ -272,7 +272,7 @@ class FeedbackTransformer(Module):
         # Softmax for weights before taking the weighted sum
         self.softmax = nn.Softmax(0)

-    def __call__(self, x_seq: torch.Tensor):
+    def forward(self, x_seq: torch.Tensor):
         """
         * `x_seq` is the input with shape `[seq_len, batch_size, d_model]`
         """
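The `nn.Softmax(0)` registered above is used to mix a stack of per-layer outputs into a single memory vector. Below is a rough sketch of that weighted sum with made-up sizes and a hypothetical learned weight vector; the real module's bookkeeping (stacking per step, feeding the memory back) is more involved.

```python
import torch
from torch import nn

n_layers, batch, d_model = 4, 2, 8
layer_outputs = torch.randn(n_layers, batch, d_model)  # one representation per layer
weights = nn.Parameter(torch.zeros(n_layers))          # hypothetical learned mixing weights

softmax = nn.Softmax(0)
memory = torch.einsum('l,lbd->bd', softmax(weights), layer_outputs)
print(memory.shape)  # torch.Size([2, 8])
```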
@@ -470,7 +470,7 @@ class FeedbackTransformerKV(Module):
         # Memory for stacked values
         self.mem_value = Stack(512)

-    def __call__(self, x_seq: torch.Tensor):
+    def forward(self, x_seq: torch.Tensor):
         """
         * `x_seq` is the input with shape `[seq_len, batch_size, d_model]`
         """
@@ -41,7 +41,7 @@ class AutoregressiveModel(Module):
         self.transformer = transformer
         self.generator = nn.Linear(d_model, n_vocab)

-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Embed the tokens
         x = self.src_embed(x)
         # Run it through the transformer
@@ -41,7 +41,7 @@ class AutoregressiveModel(Module):
         # This will be initialized on the first call
         self.src_mask = None

-    def __call__(self, src: torch.Tensor):
+    def forward(self, src: torch.Tensor):
         # Create subsequent mask, so that the transformer can only pay attention to past tokens.
         if self.src_mask is None or self.src_mask.size(0) != len(src):
             self.src_mask = subsequent_mask(len(src)).to(src.device)
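`subsequent_mask(len(src))` in the hunk above builds a causal mask so that a position cannot attend to positions after it. A stand-in built with `torch.tril` is sketched below; the library function's exact shape and dtype may differ.

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # True where attention is allowed: position i may look at positions j <= i.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

print(causal_mask(4))
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])
```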
@@ -20,6 +20,7 @@ We decided to write a simpler implementation to make it easier readers who are n
 import dataclasses

 import torch
+from labml_helpers.module import Module
 from torch import nn
 from torch.utils.data import Dataset, DataLoader

@@ -34,7 +35,7 @@ from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, Trans
 from labml_nn.transformers.utils import subsequent_mask


-class AutoregressiveModel(nn.Module):
+class AutoregressiveModel(Module):
     """
     ## Auto regressive model
     """
@@ -51,7 +52,7 @@ class AutoregressiveModel(nn.Module):
         # This will be initialized on the first call
         self.src_mask = None

-    def __call__(self, src: torch.Tensor):
+    def forward(self, src: torch.Tensor):
         # Create subsequent mask, so that the transformer can only pay attention to past tokens.
         if self.src_mask is None or self.src_mask.size(0) != len(src):
             self.src_mask = subsequent_mask(len(src)).to(src.device)
@@ -67,7 +67,7 @@ class GPT(Module):
         # The mask will be initialized on the first call
         self.mask = None

-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Create subsequent mask if mask is not initialized
         # or if the size of the mask is different
         if self.mask is None or self.mask.size(0) != len(x):
@@ -49,7 +49,7 @@ class AutoregressiveModel(Module):
         """
         return self.encoder.layers[-1].ff_input

-    def __call__(self, src: torch.Tensor):
+    def forward(self, src: torch.Tensor):
         # Create subsequent mask, so that the transformer can only pay attention to past tokens.
         if self.src_mask is None or self.src_mask.size(0) != len(src):
             self.src_mask = subsequent_mask(len(src)).to(src.device)
@@ -26,7 +26,7 @@ class LabelSmoothingLoss(Module):
         self.size = size
         self.true_dist = None

-    def __call__(self, x: torch.Tensor, target: torch.Tensor):
+    def forward(self, x: torch.Tensor, target: torch.Tensor):
         assert x.shape[1] == self.size
         true_dist = x.clone()
         true_dist.fill_(self.smoothing / (self.size - 2))
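The `size - 2` in the hunk above spreads the smoothing mass over every class except the correct one and a padding class. Below is a sketch of the smoothed target distribution, assuming a `padding_idx` of 0 (the padding index is not visible in this hunk) and a hypothetical helper name.

```python
import torch

def smoothed_targets(target: torch.Tensor, size: int,
                     smoothing: float = 0.1, padding_idx: int = 0) -> torch.Tensor:
    # Uniform smoothing mass over the `size - 2` other classes,
    # 1 - smoothing on the true class, and nothing on the padding class.
    dist = torch.full((target.shape[0], size), smoothing / (size - 2))
    dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)
    dist[:, padding_idx] = 0.0
    return dist

print(smoothed_targets(torch.tensor([2, 3]), size=5))
```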
@@ -42,7 +42,7 @@ class PrepareForMultiHeadAttention(Module):
         # Number of dimensions in vectors in each head
         self.d_k = d_k

-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
         # We apply the linear transformation to the last dimension and split that into
         # the heads.
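The comments in the hunk above describe applying a linear map to the last dimension and splitting it into heads. A shape-only sketch of that reshape, with made-up sizes:

```python
import torch
from torch import nn

heads, d_k, d_model = 4, 16, 64
linear = nn.Linear(d_model, heads * d_k)

x = torch.randn(10, 2, d_model)              # [seq_len, batch_size, d_model]
head_shape = x.shape[:-1]
x = linear(x).view(*head_shape, heads, d_k)  # [seq_len, batch_size, heads, d_k]
print(x.shape)  # torch.Size([10, 2, 4, 16])
```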
@@ -118,7 +118,7 @@ class MultiHeadAttention(Module):
         # Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
         return torch.einsum('ibhd,jbhd->ijbh', query, key)

-    def __call__(self, *,
+    def forward(self, *,
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
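The einsum in the hunk contracts the per-head feature dimension `d`, producing one score per (query position, key position, batch, head). A quick shape check with made-up sizes:

```python
import torch

seq_q, seq_k, batch, heads, d_k = 5, 7, 2, 4, 16
query = torch.randn(seq_q, batch, heads, d_k)
key = torch.randn(seq_k, batch, heads, d_k)

# $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
print(scores.shape)  # torch.Size([5, 7, 2, 4])
```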
@@ -33,7 +33,7 @@ class EmbeddingsWithPositionalEncoding(Module):
         self.d_model = d_model
         self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))

-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
         return self.linear(x) * math.sqrt(self.d_model) + pe

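`get_positional_encoding(d_model, max_len)` registered above is not shown in this diff. The sketch below builds a standard sinusoidal table shaped `[max_len, 1, d_model]` so that slicing with `[:x.shape[0]]` broadcasts over the batch dimension; the library's actual implementation may differ in details, and the function name here is a stand-in.

```python
import math
import torch

def sinusoidal_encoding(d_model: int, max_len: int = 5000) -> torch.Tensor:
    # Standard sin/cos table (assumes an even d_model), shaped [max_len, 1, d_model].
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(1)

print(sinusoidal_encoding(64).shape)  # torch.Size([5000, 1, 64])
```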
@@ -51,7 +51,7 @@ class EmbeddingsWithLearnedPositionalEncoding(Module):
         self.d_model = d_model
         self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         pe = self.positional_encodings[:x.shape[0]]
         return self.linear(x) * math.sqrt(self.d_model) + pe

@@ -100,7 +100,7 @@ class TransformerLayer(Module):
         # Whether to save input to the feed forward layer
         self.is_save_ff_input = False

-    def __call__(self, *,
+    def forward(self, *,
                 x: torch.Tensor,
                 mask: torch.Tensor,
                 src: torch.Tensor = None,
@@ -150,7 +150,7 @@ class Encoder(Module):
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])

-    def __call__(self, x: torch.Tensor, mask: torch.Tensor):
+    def forward(self, x: torch.Tensor, mask: torch.Tensor):
         # Run through each transformer layer
         for layer in self.layers:
             x = layer(x=x, mask=mask)
@@ -172,7 +172,7 @@ class Decoder(Module):
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])

-    def __call__(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
         # Run through each transformer layer
         for layer in self.layers:
             x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
@@ -194,7 +194,7 @@ class Generator(Module):
         super().__init__()
         self.projection = nn.Linear(d_model, n_vocab)

-    def __call__(self, x):
+    def forward(self, x):
         return self.projection(x)


@@ -219,7 +219,7 @@ class EncoderDecoder(Module):
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)

-    def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
         # Run the source through encoder
         enc = self.encode(src, src_mask)
         # Run encodings and targets through decoder
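The `p.dim() > 1` check in the hunk above restricts Xavier initialization to weight matrices, leaving biases and LayerNorm parameters (both 1-D) at their defaults. A tiny standalone example of the same loop, on an arbitrary model:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 16))
for p in model.parameters():
    if p.dim() > 1:
        # Only weight matrices; 1-D biases and LayerNorm scales keep their defaults.
        nn.init.xavier_uniform_(p)
```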
@@ -36,7 +36,7 @@ class PositionalEncoding(Module):

         self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len), False)

-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
         x = x + pe
         x = self.dropout(x)
@@ -80,7 +80,7 @@ class SwitchFeedForward(Module):
         self.switch = nn.Linear(d_model, n_experts)
         self.softmax = nn.Softmax(dim=-1)

-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         """
         * `x` is the input to the switching module with shape `[seq_len, batch_size, d_model]`
         """
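The `switch` linear layer and softmax set up above score each token against the experts. Below is a minimal sketch of top-1 (switch-style) routing only: pick the highest-probability expert per token. Capacity limits, token dropping, and the expert FFNs themselves are omitted, so this is not the module's full behavior; the `Top1Router` name is hypothetical.

```python
import torch
from torch import nn

class Top1Router(nn.Module):
    def __init__(self, d_model: int, n_experts: int):
        super().__init__()
        self.switch = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor):
        route_prob = self.softmax(self.switch(x))        # [seq, batch, n_experts]
        best_prob, best_expert = route_prob.max(dim=-1)  # top-1 expert per token
        return best_prob, best_expert

prob, expert = Top1Router(d_model=8, n_experts=4)(torch.randn(5, 2, 8))
print(expert.shape)  # torch.Size([5, 2])
```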
@@ -189,7 +189,7 @@ class SwitchTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])

-    def __call__(self, *,
+    def forward(self, *,
                 x: torch.Tensor,
                 mask: torch.Tensor):
         # Normalize the vectors before doing self attention
@@ -221,7 +221,7 @@ class SwitchTransformer(Module):
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])

-    def __call__(self, x: torch.Tensor, mask: torch.Tensor):
+    def forward(self, x: torch.Tensor, mask: torch.Tensor):
         # Run through each transformer layer
         counts, route_prob, n_dropped = [], [], []
         for layer in self.layers:
@@ -34,7 +34,7 @@ class AutoregressiveModel(Module):
         self.generator = nn.Linear(d_model, n_vocab)
         self.mask = None

-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Initialize the subsequent mask
         if self.mask is None or self.mask.size(0) != len(x):
             from labml_nn.transformers.utils import subsequent_mask