Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-11-01 03:43:09 +08:00

Commit: hyperlstm

labml_nn/hypernetworks/__init__.py (new file, 0 lines)

labml_nn/hypernetworks/experiment.py (new file, 209 lines)
@@ -0,0 +1,209 @@

from typing import Callable, Any

import torch
import torch.nn as nn

from labml import lab, experiment, monit, tracker, logger
from labml.configs import option
from labml.logger import Text
from labml.utils.pytorch import get_modules
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex

from labml_nn.hypernetworks.hyper_lstm import HyperLSTM


class AutoregressiveModel(Module):
    """
    ## Auto-regressive model
    """

    def __init__(self, n_vocab: int, d_model: int, n_rhn: int, n_z: int):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Single-layer HyperLSTM
        self.lstm = HyperLSTM(d_model, d_model, n_rhn, n_z, 1)
        # Final linear layer that produces the logits
        self.generator = nn.Linear(d_model, n_vocab)

    def __call__(self, x: torch.Tensor):
        # Embed the tokens
        x = self.src_embed(x)
        # Run the embeddings through the HyperLSTM
        res, state = self.lstm(x)
        # Generate logits of the next token
        return self.generator(res), state
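

# A minimal shape sketch, not part of the original file; the vocabulary size and
# batch shape below are illustrative assumptions. The model takes token indices of
# shape `[seq_len, batch_size]` (as produced by `SequentialDataLoader` below) and
# returns per-position logits along with the recurrent state.
def _check_autoregressive_model_shapes():
    model = AutoregressiveModel(n_vocab=65, d_model=512, n_rhn=16, n_z=16)
    # 32 time steps, batch of 4
    x = torch.randint(0, 65, (32, 4))
    logits, state = model(x)
    assert logits.shape == (32, 4, 65)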


class CrossEntropyLoss(Module):
    """
    Cross entropy loss
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
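

# A shape sketch, not part of the original file; sizes are illustrative. The model
# outputs are `[seq_len, batch_size, n_vocab]` and the targets are `[seq_len, batch_size]`,
# so the `view` calls above flatten both before applying `nn.CrossEntropyLoss`.
def _check_loss_shapes():
    loss_fn = CrossEntropyLoss()
    outputs = torch.randn(32, 4, 65)
    targets = torch.randint(0, 65, (32, 4))
    # The loss is a scalar averaged over every position of every sequence
    assert loss_fn(outputs, targets).dim() == 0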


class Configs(SimpleTrainValidConfigs):
    """
    ## Configurations

    The default configs can and will be over-ridden when we start the experiment
    """

    model: AutoregressiveModel
    text: TextDataset
    batch_size: int = 20
    seq_len: int = 512
    n_tokens: int
    tokenizer: Callable = 'character'

    is_save_models = True

    optimizer: torch.optim.Adam = 'transformer_optimizer'

    accuracy = Accuracy()
    loss_func = CrossEntropyLoss()

    def init(self):
        # Create a configurable optimizer.
        # Parameters like learning rate can be changed by passing a dictionary when starting the experiment.
        optimizer = OptimizerConfigs()
        optimizer.parameters = self.model.parameters()
        optimizer.optimizer = 'Adam'
        self.optimizer = optimizer

        # Create a sequential data loader for training
        self.train_loader = SequentialDataLoader(text=self.text.train,
                                                 dataset=self.text,
                                                 batch_size=self.batch_size,
                                                 seq_len=self.seq_len)

        # Create a sequential data loader for validation
        self.valid_loader = SequentialDataLoader(text=self.text.valid,
                                                 dataset=self.text,
                                                 batch_size=self.batch_size,
                                                 seq_len=self.seq_len)

        self.state_modules = [self.accuracy]

    def sample(self):
        """
        Sampling function to generate samples periodically while training
        """
        prompt = 'It is'
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output, state = self.model(data)
            output = output.cpu()
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to the prompt
            prompt += self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.text.itos[output[-1]], Text.value)]

        logger.log(log)

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        This method is called for each batch
        """
        self.model.train(self.mode.is_train)

        # Get data and target labels
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Run the model
        output, state = self.model(data)

        # Calculate loss
        loss = self.loss_func(output, target)
        # Calculate accuracy
        self.accuracy(output, target)

        # Log the loss
        tracker.add("loss.", loss)

        # If we are in training mode, calculate the gradients
        if self.mode.is_train:
            loss.backward()
            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()


def character_tokenizer(x: str):
    return list(x)


@option(Configs.tokenizer)
def character():
    """
    Character level tokenizer
    """
    return character_tokenizer


@option(Configs.text)
def tiny_shakespeare(c: Configs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')


@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    Initialize the auto-regressive model
    """
    m = AutoregressiveModel(c.n_tokens, 512, 16, 16)
    return m.to(c.device)


def main():
    # Create experiment
    experiment.create(name="hyper_lstm", comment='')
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'seq_len': 512,
                        'epochs': 128,
                        'batch_size': 2,
                        'inner_iterations': 10})

    # This is needed to initialize models
    conf.n_tokens = conf.text.n_tokens

    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    conf.init()
    # Start the experiment
    with experiment.start():
        # `TrainValidConfigs.run`
        conf.run()


if __name__ == '__main__':
    main()
labml_nn/hypernetworks/hyper_lstm.py (new file, 136 lines)

@@ -0,0 +1,136 @@

from typing import Optional, Tuple

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.lstm import LSTMCell


class HyperLSTMCell(Module):
    def __init__(self, input_size: int, hidden_size: int, rhn_hidden_size: int, n_z: int):
        super().__init__()

        self.hidden_size = hidden_size

        # TODO: need layernorm
        # The smaller recurrent cell of the hypernetwork; it gets the main cell's
        # hidden state concatenated with the input
        self.rhn = LSTMCell(hidden_size + input_size, rhn_hidden_size)

        # Project the hypernetwork state to the embeddings $z$, one chunk of size `n_z` per gate
        self.z_h = nn.Linear(rhn_hidden_size, 4 * n_z)
        self.z_x = nn.Linear(rhn_hidden_size, 4 * n_z)
        self.z_b = nn.Linear(rhn_hidden_size, 4 * n_z, bias=False)

        # Map each embedding $z$ to a scaling vector $d$ for the recurrent weights,
        # the input weights, and the bias of each gate
        d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_h = nn.ModuleList(d_h)
        d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_x = nn.ModuleList(d_x)
        d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
        self.d_b = nn.ModuleList(d_b)

        # Shared weight matrices of the main LSTM; their rows are scaled per sample
        # by the vectors $d$ generated above
        self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])
        self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])

        # Layer normalization for each gate's pre-activation
        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])

    def __call__(self, x: torch.Tensor,
                 h: torch.Tensor, c: torch.Tensor,
                 rhn_h: torch.Tensor, rhn_c: torch.Tensor):
        # Run the hypernetwork cell on the concatenated hidden state and input
        rhn_x = torch.cat((h, x), dim=-1)
        rhn_h, rhn_c = self.rhn(rhn_x, rhn_h, rhn_c)

        # Compute the embeddings for the four gates
        z_h = self.z_h(rhn_h).chunk(4, dim=-1)
        z_x = self.z_x(rhn_h).chunk(4, dim=-1)
        z_b = self.z_b(rhn_h).chunk(4, dim=-1)

        # Compute the layer-normalized pre-activations of the four gates $i, f, g, o$
        ifgo = []
        for i in range(4):
            # Generate the scaling vectors and build per-sample weight matrices
            d_h = self.d_h[i](z_h[i])
            w_h = torch.einsum('ij,bi->bij', self.w_h[i], d_h)
            d_x = self.d_x[i](z_x[i])
            w_x = torch.einsum('ij,bi->bij', self.w_x[i], d_x)
            # Generated bias
            b = self.d_b[i](z_b[i])

            # Gate pre-activation with the generated weights and bias
            g = torch.einsum('bij,bj->bi', w_h, h) + \
                torch.einsum('bij,bj->bi', w_x, x) + \
                b

            ifgo.append(self.layer_norm[i](g))

        # $$i_t = \sigma\Big(LN\big(\hat{W}^i_h h_{t-1} + \hat{W}^i_x x_t + \hat{b}^i\big)\Big)$$
        # where $\hat{W}$ and $\hat{b}$ are the weights and biases generated by the hypernetwork
        i = torch.sigmoid(ifgo[0])
        # $$f_t = \sigma\Big(LN\big(\hat{W}^f_h h_{t-1} + \hat{W}^f_x x_t + \hat{b}^f\big)\Big)$$
        f = torch.sigmoid(ifgo[1])
        # $$g_t = \tanh\Big(LN\big(\hat{W}^g_h h_{t-1} + \hat{W}^g_x x_t + \hat{b}^g\big)\Big)$$
        g = torch.tanh(ifgo[2])
        # $$o_t = \sigma\Big(LN\big(\hat{W}^o_h h_{t-1} + \hat{W}^o_x x_t + \hat{b}^o\big)\Big)$$
        o = torch.sigmoid(ifgo[3])

        # $$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$
        c_next = f * c + i * g

        # $$h_t = o_t \odot \tanh(c_t)$$
        h_next = o * torch.tanh(c_next)

        return h_next, c_next, rhn_h, rhn_c
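

# A small sketch of the weight-generation step above, not part of the original file;
# sizes are illustrative. The two `einsum` calls in `HyperLSTMCell.__call__` scale the
# rows of a shared weight matrix by a per-sample vector `d`, which is the same as
# computing `d * (h @ w.t())`: the hypernetwork modulates each row of the weights.
def _check_weight_scaling(batch_size: int = 3, hidden_size: int = 5):
    w = torch.randn(hidden_size, hidden_size)
    d = torch.randn(batch_size, hidden_size)
    h = torch.randn(batch_size, hidden_size)
    # Per-sample weight matrices, as in `HyperLSTMCell.__call__`
    w_scaled = torch.einsum('ij,bi->bij', w, d)
    out = torch.einsum('bij,bj->bi', w_scaled, h)
    assert torch.allclose(out, d * (h @ w.t()), atol=1e-5)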


class HyperLSTM(Module):
    def __init__(self, input_size: int, hidden_size: int, rhn_hidden_size: int, n_z: int, n_layers: int):
        """
        Create a network of `n_layers` of HyperLSTM cells.
        """

        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.rhn_hidden_size = rhn_hidden_size
        # Create cells for each layer. Note that only the first layer gets the input directly.
        # Rest of the layers get the input from the layer below
        self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, rhn_hidden_size, n_z)] +
                                   [HyperLSTMCell(hidden_size, hidden_size, rhn_hidden_size, n_z) for _ in
                                    range(n_layers - 1)])

    def __call__(self, x: torch.Tensor,
                 state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
        """
        `x` has shape `[seq_len, batch_size, input_size]` and
        `state` is a tuple of `(h, c, rhn_h, rhn_c)`, where `h` and `c` have shape
        `[n_layers, batch_size, hidden_size]` and the hypernetwork states `rhn_h` and `rhn_c`
        have shape `[n_layers, batch_size, rhn_hidden_size]`.
        """
        time_steps, batch_size = x.shape[:2]

        # Initialize the state if `None`
        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            rhn_h = [x.new_zeros(batch_size, self.rhn_hidden_size) for _ in range(self.n_layers)]
            rhn_c = [x.new_zeros(batch_size, self.rhn_hidden_size) for _ in range(self.n_layers)]
        else:
            (h, c, rhn_h, rhn_c) = state
            # Unstack the tensors to get the states of each layer <br />
            # 📝 You can just work with the tensor itself but this is easier to debug
            h, c = list(torch.unbind(h)), list(torch.unbind(c))
            rhn_h, rhn_c = list(torch.unbind(rhn_h)), list(torch.unbind(rhn_c))

        # Array to collect the outputs of the final layer at each time step.
        out = []
        for t in range(time_steps):
            # Input to the first layer is the input itself
            inp = x[t]
            # Loop through the layers
            for layer in range(self.n_layers):
                # Get the next state of the layer
                h[layer], c[layer], rhn_h[layer], rhn_c[layer] = \
                    self.cells[layer](inp, h[layer], c[layer], rhn_h[layer], rhn_c[layer])
                # Input to the next layer is the hidden state of this layer
                inp = h[layer]
            # Collect the output $h$ of the final layer
            out.append(h[-1])

        # Stack the outputs and states
        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)
        rhn_h = torch.stack(rhn_h)
        rhn_c = torch.stack(rhn_c)

        return out, (h, c, rhn_h, rhn_c)
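

# A minimal smoke test, not part of the original file; sizes are illustrative.
# Shapes follow the `__call__` docstring: the input is `[seq_len, batch_size, input_size]`
# and the state tensors are stacked along a leading `n_layers` dimension.
def _check_hyper_lstm_shapes():
    lstm = HyperLSTM(input_size=8, hidden_size=8, rhn_hidden_size=4, n_z=4, n_layers=2)
    x = torch.randn(10, 3, 8)
    out, state = lstm(x)
    assert out.shape == (10, 3, 8)
    h, c, rhn_h, rhn_c = state
    assert h.shape == c.shape == (2, 3, 8)
    assert rhn_h.shape == rhn_c.shape == (2, 3, 4)
    # The returned state can be fed back in to continue the same sequences
    out2, _ = lstm(x, state)
    assert out2.shape == (10, 3, 8)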