hyperlstm

Varuna Jayasiri
2020-12-26 15:49:44 +05:30
parent 3e798f38f9
commit 48277e9334
3 changed files with 345 additions and 0 deletions

@@ -0,0 +1,209 @@
from typing import Callable, Any
import torch
import torch.nn as nn
from labml import lab, experiment, monit, tracker, logger
from labml.configs import option
from labml.logger import Text
from labml.utils.pytorch import get_modules
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
class AutoregressiveModel(Module):
"""
    ## Autoregressive model
"""
def __init__(self, n_vocab: int, d_model: int, n_rhn, n_z):
super().__init__()
# Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # HyperLSTM with a single layer
        self.lstm = HyperLSTM(d_model, d_model, n_rhn, n_z, 1)
        # Final linear layer to generate logits over the vocabulary
        self.generator = nn.Linear(d_model, n_vocab)
def __call__(self, x: torch.Tensor):
        # Embed the tokens
        x = self.src_embed(x)
        # Run the embeddings through the HyperLSTM
        res, state = self.lstm(x)
# Generate logits of the next token
return self.generator(res), state
class CrossEntropyLoss(Module):
"""
Cross entropy loss
"""
def __init__(self):
super().__init__()
self.loss = nn.CrossEntropyLoss()
def __call__(self, outputs, targets):
return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
class Configs(SimpleTrainValidConfigs):
"""
## Configurations
    The default configurations can be overridden by passing a dictionary when we start the experiment.
"""
model: AutoregressiveModel
text: TextDataset
batch_size: int = 20
seq_len: int = 512
n_tokens: int
tokenizer: Callable = 'character'
is_save_models = True
optimizer: torch.optim.Adam = 'transformer_optimizer'
accuracy = Accuracy()
loss_func = CrossEntropyLoss()
def init(self):
# Create a configurable optimizer.
# Parameters like learning rate can be changed by passing a dictionary when starting the experiment.
optimizer = OptimizerConfigs()
optimizer.parameters = self.model.parameters()
optimizer.optimizer = 'Adam'
self.optimizer = optimizer
# Create a sequential data loader for training
self.train_loader = SequentialDataLoader(text=self.text.train,
dataset=self.text,
batch_size=self.batch_size,
seq_len=self.seq_len)
# Create a sequential data loader for validation
self.valid_loader = SequentialDataLoader(text=self.text.valid,
dataset=self.text,
batch_size=self.batch_size,
seq_len=self.seq_len)
self.state_modules = [self.accuracy]
def sample(self):
"""
Sampling function to generate samples periodically while training
"""
prompt = 'It is'
log = [(prompt, Text.subtle)]
# Sample 25 tokens
for i in monit.iterate('Sample', 25):
# Tokenize the prompt
data = self.text.text_to_i(prompt).unsqueeze(-1)
data = data.to(self.device)
# Get the model output
output, state = self.model(data)
output = output.cpu()
# Get the model prediction (greedy)
output = output.argmax(dim=-1).squeeze()
# Add the prediction to prompt
prompt += self.text.itos[output[-1]]
# Add the prediction for logging
log += [(self.text.itos[output[-1]], Text.value)]
logger.log(log)
def step(self, batch: Any, batch_idx: BatchIndex):
"""
This method is called for each batch
"""
self.model.train(self.mode.is_train)
# Get data and target labels
data, target = batch[0].to(self.device), batch[1].to(self.device)
if self.mode.is_train:
tracker.add_global_step(data.shape[0] * data.shape[1])
# Run the model
output, state = self.model(data)
# Calculate loss
loss = self.loss_func(output, target)
# Calculate accuracy
self.accuracy(output, target)
# Log the loss
tracker.add("loss.", loss)
# If we are in training mode, calculate the gradients
if self.mode.is_train:
loss.backward()
self.optimizer.step()
if batch_idx.is_last:
tracker.add('model', self.model)
self.optimizer.zero_grad()
tracker.save()
def character_tokenizer(x: str):
return list(x)
@option(Configs.tokenizer)
def character():
"""
Character level tokenizer
"""
return character_tokenizer
@option(Configs.text)
def tiny_shakespeare(c: Configs):
return TextFileDataset(
lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer,
url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
@option(Configs.model)
def autoregressive_model(c: Configs):
"""
Initialize the auto-regressive model
"""
m = AutoregressiveModel(c.n_tokens, 512, 16, 16)
return m.to(c.device)
def main():
# Create experiment
experiment.create(name="knn_lm", comment='')
# Create configs
conf = Configs()
# Load configurations
experiment.configs(conf,
# A dictionary of configurations to override
{'tokenizer': 'character',
'text': 'tiny_shakespeare',
'seq_len': 512,
'epochs': 128,
'batch_size': 2,
'inner_iterations': 10})
# This is needed to initialize models
conf.n_tokens = conf.text.n_tokens
# Set models for saving and loading
experiment.add_pytorch_models(get_modules(conf))
conf.init()
# Start the experiment
with experiment.start():
# `TrainValidConfigs.run`
conf.run()
if __name__ == '__main__':
main()
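
A quick smoke test for the model above (illustrative only, not part of the commit): the vocabulary size, sequence length and batch size below are made up, while `512, 16, 16` match the `autoregressive_model` config option.

import torch

# Hypothetical character vocabulary of 65 tokens
model = AutoregressiveModel(n_vocab=65, d_model=512, n_rhn=16, n_z=16)
# Dummy batch of token indices with shape `[seq_len, batch_size]`
tokens = torch.randint(0, 65, (32, 2))
logits, state = model(tokens)
# Next-token logits for every position: `[seq_len, batch_size, n_vocab]`
assert logits.shape == (32, 2, 65)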

@@ -0,0 +1,136 @@
from typing import Optional, Tuple
import torch
from labml_helpers.module import Module
from torch import nn
from labml_nn.lstm import LSTMCell
class HyperLSTMCell(Module):
    """
    ## HyperLSTM cell
    A small LSTM cell (the hypernetwork) generates scaling vectors and biases
    that modulate the weights of the main LSTM cell at every step.
    """
    def __init__(self, input_size: int, hidden_size: int, rhn_hidden_size: int, n_z: int):
        super().__init__()
        self.hidden_size = hidden_size
        # TODO: need layernorm
        # The hypernetwork LSTM cell; its input is the main hidden state
        # concatenated with the cell input
        self.rhn = LSTMCell(hidden_size + input_size, rhn_hidden_size)
        # Linear layers that map the hypernetwork state to the embeddings
        # $z_h$, $z_x$ and $z_b$ (one chunk per gate)
        self.z_h = nn.Linear(rhn_hidden_size, 4 * n_z)
        self.z_x = nn.Linear(rhn_hidden_size, 4 * n_z)
        self.z_b = nn.Linear(rhn_hidden_size, 4 * n_z, bias=False)
        # Linear layers that map each embedding to a per-unit scaling vector
        # ($d_h$, $d_x$) or a dynamic bias ($d_b$), one set per gate
        d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_h = nn.ModuleList(d_h)
        d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_x = nn.ModuleList(d_x)
        d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
        self.d_b = nn.ModuleList(d_b)
        # Weight matrices of the main LSTM, one per gate; their rows get scaled by $d_h$ and $d_x$
        self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])
        self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])
        # Layer normalization for each gate pre-activation
        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
def __call__(self, x: torch.Tensor,
h: torch.Tensor, c: torch.Tensor,
rhn_h: torch.Tensor, rhn_c: torch.Tensor):
        # The hypernetwork LSTM sees the previous hidden state concatenated with the input
        rhn_x = torch.cat((h, x), dim=-1)
        rhn_h, rhn_c = self.rhn(rhn_x, rhn_h, rhn_c)
        # Compute the embeddings $z_h$, $z_x$ and $z_b$ from the hypernetwork state
        # and split them into one chunk per gate
        z_h = self.z_h(rhn_h).chunk(4, dim=-1)
        z_x = self.z_x(rhn_h).chunk(4, dim=-1)
        z_b = self.z_b(rhn_h).chunk(4, dim=-1)
        # For each gate $y \in \{i, f, g, o\}$ compute the pre-activation
        # $$\hat{y}_t = \mathrm{LN}\Big(d^y_h(z^y_h) \odot \big(W^y_h h_{t-1}\big)
        #             + d^y_x(z^y_x) \odot \big(W^y_x x_t\big) + d^y_b(z^y_b)\Big)$$
        # where $d^y_h$, $d^y_x$ and $d^y_b$ are the dynamically generated
        # scaling vectors and bias.
        ifgo = []
        for i in range(4):
            # Scaling vector for the recurrent weights
            d_h = self.d_h[i](z_h[i])
            # Row-scale $W^y_h$ separately for each sample in the batch
            w_h = torch.einsum('ij,bi->bij', self.w_h[i], d_h)
            # Scaling vector for the input weights
            d_x = self.d_x[i](z_x[i])
            # Row-scale $W^y_x$ separately for each sample in the batch
            w_x = torch.einsum('ij,bi->bij', self.w_x[i], d_x)
            # Dynamically generated bias
            b = self.d_b[i](z_b[i])
            # Gate pre-activation before layer normalization
            g = torch.einsum('bij,bj->bi', w_h, h) + \
                torch.einsum('bij,bj->bi', w_x, x) + \
                b
            ifgo.append(self.layer_norm[i](g))
        # $$i_t = \sigma(\hat{i}_t)$$
        i = torch.sigmoid(ifgo[0])
        # $$f_t = \sigma(\hat{f}_t)$$
        f = torch.sigmoid(ifgo[1])
        # $$g_t = \tanh(\hat{g}_t)$$
        g = torch.tanh(ifgo[2])
        # $$o_t = \sigma(\hat{o}_t)$$
        o = torch.sigmoid(ifgo[3])
# $$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$
c_next = f * c + i * g
# $$h_t = o_t \odot \tanh(c_t)$$
h_next = o * torch.tanh(c_next)
return h_next, c_next, rhn_h, rhn_c
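# Aside (illustrative sketch, not part of the committed file): the two einsum
# calls above first build a `[batch_size, hidden_size, hidden_size]` weight
# tensor per gate and then apply it to a vector; together they compute
# $d \odot (W h)$. The check below (using the `torch` import at the top of
# this file) verifies that against the cheaper direct form.
_b, _n = 3, 8
_w = torch.randn(_n, _n)
_d = torch.randn(_b, _n)
_h = torch.randn(_b, _n)
# Batched, row-scaled weight matrix, as in `HyperLSTMCell.__call__`
_w_scaled = torch.einsum('ij,bi->bij', _w, _d)
_out_einsum = torch.einsum('bij,bj->bi', _w_scaled, _h)
# Equivalent: scale the plain matrix-vector product element-wise by $d$
_out_direct = _d * torch.einsum('ij,bj->bi', _w, _h)
assert torch.allclose(_out_einsum, _out_direct, atol=1e-6)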
class HyperLSTM(Module):
def __init__(self, input_size: int, hidden_size: int, rhn_hidden_size: int, n_z: int, n_layers: int):
"""
        Create a network of `n_layers` of HyperLSTM cells.
"""
super().__init__()
self.n_layers = n_layers
self.hidden_size = hidden_size
self.rhn_hidden_size = rhn_hidden_size
# Create cells for each layer. Note that only the first layer gets the input directly.
# Rest of the layers get the input from the layer below
self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, rhn_hidden_size, n_z)] +
[HyperLSTMCell(hidden_size, hidden_size, rhn_hidden_size, n_z) for _ in
range(n_layers - 1)])
def __call__(self, x: torch.Tensor,
state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
"""
        `x` has shape `[seq_len, batch_size, input_size]` and
        `state` is a tuple `(h, c, rhn_h, rhn_c)`, where `h` and `c` have shape
        `[n_layers, batch_size, hidden_size]` and `rhn_h` and `rhn_c` have shape
        `[n_layers, batch_size, rhn_hidden_size]`.
"""
time_steps, batch_size = x.shape[:2]
# Initialize the state if `None`
if state is None:
h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
rhn_h = [x.new_zeros(batch_size, self.rhn_hidden_size) for _ in range(self.n_layers)]
rhn_c = [x.new_zeros(batch_size, self.rhn_hidden_size) for _ in range(self.n_layers)]
else:
(h, c, rhn_h, rhn_c) = state
# Reverse stack the tensors to get the states of each layer <br />
# 📝 You can just work with the tensor itself but this is easier to debug
h, c = list(torch.unbind(h)), list(torch.unbind(c))
rhn_h, rhn_c = list(torch.unbind(rhn_h)), list(torch.unbind(rhn_c))
# Array to collect the outputs of the final layer at each time step.
out = []
for t in range(time_steps):
# Input to the first layer is the input itself
inp = x[t]
# Loop through the layers
for layer in range(self.n_layers):
                # Compute the new state of this layer
h[layer], c[layer], rhn_h[layer], rhn_c[layer] = \
self.cells[layer](inp, h[layer], c[layer], rhn_h[layer], rhn_c[layer])
# Input to the next layer is the state of this layer
inp = h[layer]
# Collect the output $h$ of the final layer
out.append(h[-1])
# Stack the outputs and states
out = torch.stack(out)
h = torch.stack(h)
c = torch.stack(c)
rhn_h = torch.stack(rhn_h)
rhn_c = torch.stack(rhn_c)
return out, (h, c, rhn_h, rhn_c)
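
A minimal usage sketch for `HyperLSTM` (illustrative only, not part of the commit). It assumes `labml_nn` is installed so the module can be imported as in the experiment file above; the sizes are arbitrary and the asserts just confirm the shapes described in `__call__`.

import torch
from labml_nn.hypernetworks.hyper_lstm import HyperLSTM

seq_len, batch_size, input_size = 10, 4, 32
hidden_size, rhn_hidden_size, n_z, n_layers = 64, 16, 16, 2
lstm = HyperLSTM(input_size, hidden_size, rhn_hidden_size, n_z, n_layers)
# Dummy input of shape `[seq_len, batch_size, input_size]`
x = torch.randn(seq_len, batch_size, input_size)
out, (h, c, rhn_h, rhn_c) = lstm(x)
# Outputs of the last layer at every time step
assert out.shape == (seq_len, batch_size, hidden_size)
# Per-layer LSTM state and hypernetwork LSTM state, stacked along the layer dimension
assert h.shape == c.shape == (n_layers, batch_size, hidden_size)
assert rhn_h.shape == rhn_c.shape == (n_layers, batch_size, rhn_hidden_size)
# The returned state can be fed back in to continue the sequence
out2, state = lstm(x, (h, c, rhn_h, rhn_c))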