mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-15 18:27:20 +08:00
feedback transformer
197  labml_nn/transformers/feedback/__init__.py  Normal file
@@ -0,0 +1,197 @@
"""
---
title: Feedback Transformer
summary: >
  This is an implementation of the Feedback Transformer in PyTorch with explanations.
---
"""

import math
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.models import FeedForward
from labml_nn.utils import clone_module_list


class PrepareQueryForMultiHeadAttention(Module):
    """
    ## Prepare query for multi-head attention

    This module does a linear transformation and splits the vector into the given
    number of heads for multi-head attention.
    It is used to transform the **query** vector, which is for a single step and so
    has no sequence dimension; keys and values are prepared with
    `PrepareForMultiHeadAttention` from the base multi-head attention module.
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        # Linear layer for the linear transform
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        # Number of heads
        self.heads = heads
        # Number of dimensions in the vectors of each head
        self.d_k = d_k

    def __call__(self, x: torch.Tensor):
        # Input has shape `[batch_size, d_model]`
        batch_size, _ = x.shape

        # Linear transform
        x = self.linear(x)
        # Split into heads
        x = x.view(batch_size, self.heads, self.d_k)

        # Output has shape `[batch_size, heads, d_k]`
        return x
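
# Illustrative only: a minimal shape-check sketch for the module above; the helper
# name and dimensions are arbitrary, not from the original implementation.
def _check_prepare_query_shapes():
    prep = PrepareQueryForMultiHeadAttention(d_model=512, heads=8, d_k=64, bias=False)
    # A single-step query: `[batch_size, d_model]`
    q = torch.zeros(32, 512)
    # After the linear transform and split into heads: `[batch_size, heads, d_k]`
    assert prep(q).shape == (32, 8, 64)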


class FeedbackAttention(Module):
    """
    ## Feedback Attention

    This computes attention for a single-step query over the memory of all previous
    steps, with relative positional encodings.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__()

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        self.query = PrepareQueryForMultiHeadAttention(d_model, heads, self.d_k, False)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # We store the attentions so that they can be used for logging, or other computations if needed
        self.attn = None

        # Maximum memory length supported by the relative positional embeddings
        self.P = 2 ** 12

        # Relative positional embeddings and biases for the keys, and a bias for the query
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
        # Softmax over the memory (sequence) dimension
        self.softmax = nn.Softmax(dim=0)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        # Slice out the positional embeddings and biases for the current memory length
        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
        key_pos_bias = self.key_pos_bias[-key.shape[0]:]
        query_pos_bias = self.query_pos_bias[None, :, :]

        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]

        return ac + bd
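
    # A note on the scores above: one way to read the two einsums, in the spirit of
    # Transformer-XL style relative attention (the grouping follows the variable names), is
    #
    # $$A_{j} = \underbrace{(q + u)^\top k_j}_{ac} + \underbrace{q^\top p_j + v_j}_{bd}$$
    #
    # where $u$ is `query_pos_bias`, $p_j$ is the `key_pos_embeddings` entry and
    # $v_j$ is the `key_pos_bias` entry for memory position $j$.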

    def __call__(self, *,
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor):
        # `query` has shape `[batch_size, d_model]`, while `key` and `value`
        # have shape `[seq_len, batch_size, d_model]`
        batch_size, _ = query.shape

        # Prepare `query`, `key` and `value` for the attention computation.
        # `query` will then have shape `[batch_size, heads, d_k]`, and
        # `key` and `value` will have shape `[seq_len, batch_size, heads, d_k]`
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Compute attention scores $Q K^\top$.
        # This gives a tensor of shape `[seq_len, batch_size, heads]`
        scores = self.get_scores(query, key)

        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= self.scale

        # Softmax over the sequence (memory) dimension
        attn = self.softmax(scores)

        # Apply dropout
        attn = self.dropout(attn)

        # Multiply by the values
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
        x = torch.einsum("jbh,jbhd->bhd", attn, value)

        # Save attentions for any other calculations
        self.attn = attn.detach()

        # Concatenate the heads
        x = x.reshape(batch_size, -1)

        # Output layer
        return self.output(x)
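
# Illustrative only: a minimal sketch of calling `FeedbackAttention`; the helper name
# and dimensions are arbitrary. The query is a single step, while the keys and values
# come from a memory of earlier steps.
def _check_feedback_attention_shapes():
    attn = FeedbackAttention(heads=8, d_model=512)
    step = torch.zeros(32, 512)     # current step: `[batch_size, d_model]`
    mem = torch.zeros(10, 32, 512)  # memory of 10 steps: `[mem_len, batch_size, d_model]`
    out = attn(query=step, key=mem, value=mem)
    # Output has shape `[batch_size, d_model]`
    assert out.shape == (32, 512)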


class FeedbackTransformerLayer(Module):
    """
    ## Feedback Transformer Layer

    A single transformer layer with pre-layer-norm, feedback attention and a
    position-wise feed-forward network.
    """

    def __init__(self, *,
                 d_model: int,
                 attn: FeedbackAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        # Transformer size $d_{model}$
        self.size = d_model
        # Feedback attention module
        self.attn = attn
        # Position-wise feed-forward network
        self.feed_forward = feed_forward
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Layer normalizations before attention and feed-forward
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def __call__(self, *,
                 x: torch.Tensor,
                 mem: Optional[torch.Tensor]):
        # Normalize the vectors before doing attention
        z = self.norm_self_attn(x)
        if mem is not None:
            # Run through attention, with keys and values taken from the memory
            self_attn = self.attn(query=z, key=mem, value=mem)
            # Add the attention results (residual connection)
            x = x + self.dropout(self_attn)

        # Normalize for the feed-forward network
        z = self.norm_ff(x)
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results back (residual connection)
        x = x + self.dropout(ff)

        return x


class FeedbackTransformer(Module):
    """
    ## Feedback Transformer

    This runs the input sequence one step at a time. For each step it stores a
    memory vector that is a softmax-weighted sum of the outputs of all layers
    (including the embedding) at that step; attention at later steps uses these
    memory vectors as keys and values, so every layer can attend to the
    highest-level representations of the past.
    """

    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
        # Learned weights for combining the layer outputs into the memory
        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
        # Softmax for the memory weights
        self.softmax = nn.Softmax(0)

    def __call__(self, x_seq: torch.Tensor):
        # Split the input of shape `[seq_len, batch_size, d_model]` into a list of steps
        x_seq = torch.unbind(x_seq, dim=0)
        # List of outputs
        res = []
        # Memory vectors, one per previous step
        mem = []
        # Process the sequence step by step
        for x in x_seq:
            # Layer outputs at this step, starting with the embedding
            emb = [x]
            # Stack the memory of previous steps, if any
            mem_tensor = None
            if mem:
                mem_tensor = torch.stack(mem)
            # Run through each transformer layer
            for layer in self.layers:
                x = layer(x=x, mem=mem_tensor)
                emb.append(x)
            # Stack the layer outputs: `[n_layers + 1, batch_size, d_model]`
            emb = torch.stack(emb)
            # The memory for this step is a softmax-weighted sum of the layer outputs
            mem.append(torch.einsum('lbd,l->bd', emb, self.softmax(self.weights)))
            # Collect the output of the final layer
            res.append(x)

        # Finally, stack the outputs and normalize the vectors
        res = torch.stack(res)
        return self.norm(res)
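
# Illustrative only: an end-to-end sketch mirroring how the experiment builds the model,
# with small arbitrary dimensions and a made-up helper name. The memory written after
# step $t$ is $m_t = \sum_{l=0}^{L} \mathrm{softmax}(w)_l \, x_t^l$, where $x_t^0$ is the
# step's embedding and $x_t^l$ is the output of layer $l$.
def _check_feedback_transformer_shapes():
    d_model, heads, d_ff, n_layers = 64, 4, 128, 2
    layer = FeedbackTransformerLayer(d_model=d_model,
                                     attn=FeedbackAttention(heads, d_model, 0.1),
                                     feed_forward=FeedForward(d_model, d_ff, 0.1),
                                     dropout_prob=0.1)
    model = FeedbackTransformer(layer, n_layers)
    # Input and output have shape `[seq_len, batch_size, d_model]`
    x = torch.zeros(16, 8, d_model)
    assert model(x).shape == (16, 8, d_model)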
94  labml_nn/transformers/feedback/experiment.py  Normal file
@@ -0,0 +1,94 @@
import torch
import torch.nn as nn

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs


class AutoregressiveModel(Module):
    """
    ## Auto-regressive model
    """

    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Feedback transformer
        self.transformer = transformer
        # Final layer that generates logits over the vocabulary
        self.generator = nn.Linear(d_model, n_vocab)

    def __call__(self, x: torch.Tensor):
        # Embed the tokens (`src`)
        x = self.src_embed(x)
        # Run them through the transformer
        res = self.transformer(x)
        # Generate logits of the next token
        return self.generator(res), None
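
# Illustrative only: a minimal shape check for the model above; the vocabulary size,
# dimensions and helper name are arbitrary.
def _check_autoregressive_model_shapes():
    from labml_nn.transformers.feedback import FeedbackTransformer, FeedbackTransformerLayer, \
        FeedbackAttention, FeedForward
    d_model, heads, d_ff, n_vocab = 64, 4, 128, 100
    transformer = FeedbackTransformer(
        FeedbackTransformerLayer(d_model=d_model,
                                 attn=FeedbackAttention(heads, d_model, 0.0),
                                 feed_forward=FeedForward(d_model, d_ff, 0.0),
                                 dropout_prob=0.0),
        2)
    model = AutoregressiveModel(n_vocab, d_model, transformer)
    # Token indices have shape `[seq_len, batch_size]`;
    # logits have shape `[seq_len, batch_size, n_vocab]`
    x = torch.zeros(16, 8, dtype=torch.long)
    logits, _ = model(x)
    assert logits.shape == (16, 8, n_vocab)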


class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    The default configs can and will be overridden when we start the experiment.
    """

    model: AutoregressiveModel

    # Model dimensions and hyper-parameters
    d_model: int = 512
    heads: int = 8
    dropout: float = 0.0
    d_ff: int = 2048
    n_layers: int = 6


@option(Configs.model)
def autoregressive_model(c: Configs):
    from labml_nn.transformers.feedback import FeedbackTransformer, FeedbackTransformerLayer, \
        FeedbackAttention, FeedForward

    return AutoregressiveModel(
        c.n_tokens, c.d_model,
        FeedbackTransformer(
            FeedbackTransformerLayer(d_model=c.d_model,
                                     attn=FeedbackAttention(c.heads, c.d_model, c.dropout),
                                     feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                     dropout_prob=c.dropout),
            c.n_layers)).to(c.device)


def main():
    # Create experiment
    experiment.create(name="feedback_transformer")
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.0,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 64,
                        'epochs': 128,
                        'batch_size': 80,
                        'inner_iterations': 25})

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Initialize the configurations
    conf.init()
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()