import torch
import torch.nn as nn

from labml import experiment
from labml.configs import option
from labml.utils.pytorch import get_modules
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
from labml_nn.lstm import LSTM


class AutoregressiveModel(nn.Module):
    """
    ## Auto regressive model

    Embeds input tokens, runs the embeddings through a recurrent model
    and projects the recurrent outputs to logits over the vocabulary.
    """

    def __init__(self, n_vocab: int, d_model: int, rnn_model: nn.Module):
        """
        * `n_vocab` is the vocabulary size (rows of the embedding table and
          width of the output logits)
        * `d_model` is the embedding size; it is also the input size of the
          output projection, so `rnn_model` must emit `d_model`-sized features
        * `rnn_model` is the recurrent module; `forward` unpacks its result
          as an `(output, state)` pair
        """
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Recurrent model. The attribute keeps the name `lstm` (even when it
        # holds a HyperLSTM) because `state_dict` keys are derived from
        # attribute names, so renaming it would break saved checkpoints.
        self.lstm = rnn_model
        # Projection from recurrent features to next-token logits
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
        """
        * `x` holds token ids (layout is whatever `rnn_model` expects after
          embedding — presumably `[seq_len, batch_size]`; TODO confirm)

        Returns `(logits, state)`, where `state` is passed through unchanged
        from the recurrent model.
        """
        # Embed the tokens (`src`)
        x = self.src_embed(x)
        # Run the embeddings through the recurrent model (HyperLSTM / LSTM)
        res, state = self.lstm(x)
        # Generate logits of the next token
        return self.generator(res), state


class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    The default configs can and will be overridden when we start the experiment.
    """

    model: AutoregressiveModel
    rnn_model: nn.Module

    # Embedding size; also used as both the input and hidden size of the RNN
    d_model: int = 512
    # Third positional argument to `HyperLSTM`
    # (presumably the hyper-network's hidden size — confirm)
    n_rhn: int = 16
    # Fourth positional argument to `HyperLSTM`
    # (presumably the size of the `z` feature embeddings — confirm)
    n_z: int = 16
    # Last positional argument to `HyperLSTM` / `LSTM`, presumably the number
    # of stacked recurrent layers. Defaults to 1, the value that used to be
    # hard-coded in the option factories below.
    n_layers: int = 1


@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    Initialize the auto-regressive model and move it to `c.device`
    """
    m = AutoregressiveModel(c.n_tokens, c.d_model, c.rnn_model)
    return m.to(c.device)


@option(Configs.rnn_model)
def hyper_lstm(c: Configs):
    """
    Create a `HyperLSTM` whose input and hidden sizes are both `c.d_model`
    """
    return HyperLSTM(c.d_model, c.d_model, c.n_rhn, c.n_z, c.n_layers)


@option(Configs.rnn_model)
def lstm(c: Configs):
    """
    Create a plain `LSTM` baseline whose input and hidden sizes are both `c.d_model`
    """
    return LSTM(c.d_model, c.d_model, c.n_layers)


def main():
    """
    Configure and run the HyperLSTM character-level language-modelling experiment
    """
    # Create experiment
    experiment.create(name="hyper_lstm", comment='')
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 2.5e-4,
                        'optimizer.optimizer': 'Adam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'rnn_model': 'hyper_lstm',

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 512,
                        'epochs': 128,
                        'batch_size': 2,
                        'inner_iterations': 25})

    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    # Start the experiment
    with experiment.start():
        # `TrainValidConfigs.run`
        conf.run()


if __name__ == '__main__':
    main()