diff --git a/labml_nn/hypernetworks/__init__.py b/labml_nn/hypernetworks/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/labml_nn/hypernetworks/experiment.py b/labml_nn/hypernetworks/experiment.py
new file mode 100644
index 00000000..20c915b9
--- /dev/null
+++ b/labml_nn/hypernetworks/experiment.py
@@ -0,0 +1,209 @@
+from typing import Callable, Any
+
+import torch
+import torch.nn as nn
+from labml import lab, experiment, monit, tracker, logger
+from labml.configs import option
+from labml.logger import Text
+from labml.utils.pytorch import get_modules
+from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
+from labml_helpers.metrics.accuracy import Accuracy
+from labml_helpers.module import Module
+from labml_helpers.optimizer import OptimizerConfigs
+from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
+
+from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
+
+
+class AutoregressiveModel(Module):
+    """
+    ## Auto-regressive model
+    """
+
+    def __init__(self, n_vocab: int, d_model: int, n_rhn, n_z):
+        super().__init__()
+        # Token embedding module
+        self.src_embed = nn.Embedding(n_vocab, d_model)
+        self.lstm = HyperLSTM(d_model, d_model, n_rhn, n_z, 1)
+        self.generator = nn.Linear(d_model, n_vocab)
+
+    def __call__(self, x: torch.Tensor):
+        # Embed the tokens
+        x = self.src_embed(x)
+        # Run the embeddings through the HyperLSTM
+        res, state = self.lstm(x)
+        # Generate logits of the next token
+        return self.generator(res), state
+
+
+class CrossEntropyLoss(Module):
+    """
+    Cross entropy loss
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.loss = nn.CrossEntropyLoss()
+
+    def __call__(self, outputs, targets):
+        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
+
+
+class Configs(SimpleTrainValidConfigs):
+    """
+    ## Configurations
+
+    The default configs can and will be overridden when we start the experiment
+    """
+
+    model: AutoregressiveModel
+    text: TextDataset
+    batch_size: int = 20
+    seq_len: int = 512
+    n_tokens: int
+    tokenizer: Callable = 'character'
+
+    is_save_models = True
+
+    optimizer: torch.optim.Adam = 'transformer_optimizer'
+
+    accuracy = Accuracy()
+    loss_func = CrossEntropyLoss()
+
+    def init(self):
+        # Create a configurable optimizer.
+        # Parameters like learning rate can be changed by passing a dictionary when starting the experiment.
+        optimizer = OptimizerConfigs()
+        optimizer.parameters = self.model.parameters()
+        optimizer.optimizer = 'Adam'
+        self.optimizer = optimizer
+
+        # Create a sequential data loader for training
+        self.train_loader = SequentialDataLoader(text=self.text.train,
+                                                 dataset=self.text,
+                                                 batch_size=self.batch_size,
+                                                 seq_len=self.seq_len)
+
+        # Create a sequential data loader for validation
+        self.valid_loader = SequentialDataLoader(text=self.text.valid,
+                                                 dataset=self.text,
+                                                 batch_size=self.batch_size,
+                                                 seq_len=self.seq_len)
+
+        self.state_modules = [self.accuracy]
+
+    def sample(self):
+        """
+        Sampling function to generate samples periodically while training
+        """
+        prompt = 'It is'
+        log = [(prompt, Text.subtle)]
+        # Sample 25 tokens
+        for i in monit.iterate('Sample', 25):
+            # Tokenize the prompt
+            data = self.text.text_to_i(prompt).unsqueeze(-1)
+            data = data.to(self.device)
+            # Get the model output
+            output, state = self.model(data)
+            output = output.cpu()
+            # Get the model prediction (greedy)
+            output = output.argmax(dim=-1).squeeze()
+            # Add the prediction to the prompt
+            prompt += self.text.itos[output[-1]]
+            # Add the prediction for logging
+            log += [(self.text.itos[output[-1]], Text.value)]
+
+        logger.log(log)
+
+    def step(self, batch: Any, batch_idx: BatchIndex):
+        """
+        This method is called for each batch
+        """
+        self.model.train(self.mode.is_train)
+
+        # Get data and target labels
+        data, target = batch[0].to(self.device), batch[1].to(self.device)
+
+        if self.mode.is_train:
+            tracker.add_global_step(data.shape[0] * data.shape[1])
+
+        # Run the model
+        output, state = self.model(data)
+
+        # Calculate loss
+        loss = self.loss_func(output, target)
+        # Calculate accuracy
+        self.accuracy(output, target)
+
+        # Log the loss
+        tracker.add("loss.", loss)
+
+        # If we are in training mode, calculate the gradients
+        if self.mode.is_train:
+            loss.backward()
+            self.optimizer.step()
+            if batch_idx.is_last:
+                tracker.add('model', self.model)
+            self.optimizer.zero_grad()
+
+        tracker.save()
+
+
+def character_tokenizer(x: str):
+    return list(x)
+
+
+@option(Configs.tokenizer)
+def character():
+    """
+    Character level tokenizer
+    """
+    return character_tokenizer
+
+
+@option(Configs.text)
+def tiny_shakespeare(c: Configs):
+    return TextFileDataset(
+        lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer,
+        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
+
+
+@option(Configs.model)
+def autoregressive_model(c: Configs):
+    """
+    Initialize the auto-regressive model
+    """
+    m = AutoregressiveModel(c.n_tokens, 512, 16, 16)
+    return m.to(c.device)
+
+
+def main():
+    # Create experiment
+    experiment.create(name="hyper_lstm", comment='')
+    # Create configs
+    conf = Configs()
+    # Load configurations
+    experiment.configs(conf,
+                       # A dictionary of configurations to override
+                       {'tokenizer': 'character',
+                        'text': 'tiny_shakespeare',
+
+                        'seq_len': 512,
+                        'epochs': 128,
+                        'batch_size': 2,
+                        'inner_iterations': 10})
+
+    # This is needed to initialize models
+    conf.n_tokens = conf.text.n_tokens
+
+    # Set models for saving and loading
+    experiment.add_pytorch_models(get_modules(conf))
+
+    conf.init()
+    # Start the experiment
+    with experiment.start():
+        # `TrainValidConfigs.run`
+        conf.run()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/labml_nn/hypernetworks/hyper_lstm.py b/labml_nn/hypernetworks/hyper_lstm.py
new file mode 100644
index 00000000..28bc9730
--- /dev/null
+++ b/labml_nn/hypernetworks/hyper_lstm.py
@@ -0,0 +1,136 @@
+from typing import Optional, Tuple
+
+import torch
+from labml_helpers.module import Module
+from torch import nn
+
+from labml_nn.lstm import LSTMCell
+
+
+class HyperLSTMCell(Module):
+    def __init__(self, input_size: int, hidden_size: int, rhn_hidden_size: int, n_z: int):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+
+        # The hypernetwork is a smaller LSTM cell that sees the concatenated hidden state and input
+        # TODO: need layernorm
+        self.rhn = LSTMCell(hidden_size + input_size, rhn_hidden_size)
+
+        # Linear layers that compute the embeddings $z_h$, $z_x$ and $z_b$ from the hypernetwork state
+        self.z_h = nn.Linear(rhn_hidden_size, 4 * n_z)
+        self.z_x = nn.Linear(rhn_hidden_size, 4 * n_z)
+        self.z_b = nn.Linear(rhn_hidden_size, 4 * n_z, bias=False)
+
+        # Linear layers that map the embeddings to the weight-scaling vectors and biases for each gate
+        d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
+        self.d_h = nn.ModuleList(d_h)
+        d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
+        self.d_x = nn.ModuleList(d_x)
+        d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
+        self.d_b = nn.ModuleList(d_b)
+
+        # Base weight matrices for the four gates; their rows get scaled by the generated vectors
+        self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])
+        self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])
+
+        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
+
+    def __call__(self, x: torch.Tensor,
+                 h: torch.Tensor, c: torch.Tensor,
+                 rhn_h: torch.Tensor, rhn_c: torch.Tensor):
+        # Run the hypernetwork LSTM cell on the concatenated hidden state and input
+        rhn_x = torch.cat((h, x), dim=-1)
+        rhn_h, rhn_c = self.rhn(rhn_x, rhn_h, rhn_c)
+
+        # Compute the embeddings for the four gates
+        z_h = self.z_h(rhn_h).chunk(4, dim=-1)
+        z_x = self.z_x(rhn_h).chunk(4, dim=-1)
+        z_b = self.z_b(rhn_h).chunk(4, dim=-1)
+
+        # Compute the pre-activations of the four gates $i$, $f$, $g$ and $o$
+        ifgo = []
+        for i in range(4):
+            # Generate the weight-scaling vectors and the bias for this gate
+            d_h = self.d_h[i](z_h[i])
+            w_h = torch.einsum('ij,bi->bij', self.w_h[i], d_h)
+            d_x = self.d_x[i](z_x[i])
+            w_x = torch.einsum('ij,bi->bij', self.w_x[i], d_x)
+            b = self.d_b[i](z_b[i])
+
+            g = torch.einsum('bij,bj->bi', w_h, h) + \
+                torch.einsum('bij,bj->bi', w_x, x) + \
+                b
+
+            ifgo.append(self.layer_norm[i](g))
+
+        # $$i_t = \sigma\big(LN\big(W^i_h(z_h) h_{t-1} + W^i_x(z_x) x_t + b^i(z_b)\big)\big)$$
+        i = torch.sigmoid(ifgo[0])
+        # $$f_t = \sigma\big(LN\big(W^f_h(z_h) h_{t-1} + W^f_x(z_x) x_t + b^f(z_b)\big)\big)$$
+        f = torch.sigmoid(ifgo[1])
+        # $$g_t = \tanh\big(LN\big(W^g_h(z_h) h_{t-1} + W^g_x(z_x) x_t + b^g(z_b)\big)\big)$$
+        g = torch.tanh(ifgo[2])
+        # $$o_t = \sigma\big(LN\big(W^o_h(z_h) h_{t-1} + W^o_x(z_x) x_t + b^o(z_b)\big)\big)$$
+        o = torch.sigmoid(ifgo[3])
+
+        # $$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$
+        c_next = f * c + i * g
+
+        # $$h_t = o_t \odot \tanh(c_t)$$
+        h_next = o * torch.tanh(c_next)
+
+        return h_next, c_next, rhn_h, rhn_c
+
+
+class HyperLSTM(Module):
+    def __init__(self, input_size: int, hidden_size: int, rhn_hidden_size: int, n_z: int, n_layers: int):
+        """
+        Create a network of `n_layers` of HyperLSTM cells.
+        """
+
+        super().__init__()
+        self.n_layers = n_layers
+        self.hidden_size = hidden_size
+        self.rhn_hidden_size = rhn_hidden_size
+        # Create cells for each layer. Note that only the first layer gets the input directly.
+        # The rest of the layers get their input from the layer below.
+        self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, rhn_hidden_size, n_z)] +
+                                   [HyperLSTMCell(hidden_size, hidden_size, rhn_hidden_size, n_z) for _ in
                                    range(n_layers - 1)])
+
+    def __call__(self, x: torch.Tensor,
+                 state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
+        """
+        `x` has shape `[seq_len, batch_size, input_size]` and `state` is a tuple of $h$, $c$, $\hat{h}$ and $\hat{c}$,
+        each with shape `[n_layers, batch_size, hidden_size]` (`rhn_hidden_size` for $\hat{h}$ and $\hat{c}$).
+        """
+        time_steps, batch_size = x.shape[:2]
+
+        # Initialize the state if `None`
+        if state is None:
+            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
+            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
+            rhn_h = [x.new_zeros(batch_size, self.rhn_hidden_size) for _ in range(self.n_layers)]
+            rhn_c = [x.new_zeros(batch_size, self.rhn_hidden_size) for _ in range(self.n_layers)]
+        else:
+            (h, c, rhn_h, rhn_c) = state
+            # Reverse stack the tensors to get the states of each layer
+            # 📝 You can just work with the tensor itself but this is easier to debug
+            h, c = list(torch.unbind(h)), list(torch.unbind(c))
+            rhn_h, rhn_c = list(torch.unbind(rhn_h)), list(torch.unbind(rhn_c))
+
+        # Array to collect the outputs of the final layer at each time step.
+        out = []
+        for t in range(time_steps):
+            # Input to the first layer is the input itself
+            inp = x[t]
+            # Loop through the layers
+            for layer in range(self.n_layers):
+                # Run the cell of this layer and update its state
+                h[layer], c[layer], rhn_h[layer], rhn_c[layer] = \
+                    self.cells[layer](inp, h[layer], c[layer], rhn_h[layer], rhn_c[layer])
+                # Input to the next layer is the state of this layer
+                inp = h[layer]
+            # Collect the output $h$ of the final layer
+            out.append(h[-1])
+
+        # Stack the outputs and states
+        out = torch.stack(out)
+        h = torch.stack(h)
+        c = torch.stack(c)
+        rhn_h = torch.stack(rhn_h)
+        rhn_c = torch.stack(rhn_c)
+
+        return out, (h, c, rhn_h, rhn_c)
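
Below is a minimal usage sketch for the `HyperLSTM` module added above (not part of the diff). The sizes are illustrative assumptions rather than values taken from the experiment configuration:

    import torch

    from labml_nn.hypernetworks.hyper_lstm import HyperLSTM

    # Illustrative sizes (assumptions, not from the experiment config)
    seq_len, batch_size, input_size = 10, 4, 32
    hidden_size, rhn_hidden_size, n_z, n_layers = 64, 16, 16, 2

    # Input has shape `[seq_len, batch_size, input_size]`
    x = torch.randn(seq_len, batch_size, input_size)
    lstm = HyperLSTM(input_size, hidden_size, rhn_hidden_size, n_z, n_layers)

    # Calling without a state initializes h, c and the hypernetwork states to zeros
    out, (h, c, rhn_h, rhn_c) = lstm(x)
    assert out.shape == (seq_len, batch_size, hidden_size)
    assert h.shape == (n_layers, batch_size, hidden_size)
    assert rhn_h.shape == (n_layers, batch_size, rhn_hidden_size)

    # The returned state can be passed back in to continue the sequence
    out2, state2 = lstm(x, (h, c, rhn_h, rhn_c))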