mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 17:41:37 +08:00
102 lines
2.7 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from labml import experiment
|
|
from labml.configs import option
|
|
from labml.utils.pytorch import get_modules
|
|
|
|
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
|
|
from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
|
|
from labml_nn.lstm import LSTM
|
|
|
|
|
|
class AutoregressiveModel(nn.Module):
    """
    ## Auto regressive model

    Wraps a recurrent core (`rnn_model`) with a token embedding on the
    input side and a linear projection to vocabulary logits on the output.
    """

    def __init__(self, n_vocab: int, d_model: int, rnn_model: nn.Module):
        """
        * `n_vocab` is the vocabulary size
        * `d_model` is the embedding (and recurrent state) dimensionality
        * `rnn_model` is the recurrent core (e.g. `LSTM` or `HyperLSTM`)
        """
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Recurrent core that processes the embedded sequence
        self.lstm = rnn_model
        # Projects recurrent outputs to next-token logits
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
        # Embed the tokens and run them through the recurrent model
        embeddings = self.src_embed(x)
        out, state = self.lstm(embeddings)
        # Logits for the next token at every position, plus the recurrent state
        return self.generator(out), state
class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    The default configs can and will be over-ridden when we start the experiment
    """

    # The autoregressive wrapper model
    model: AutoregressiveModel
    # The recurrent core plugged into the model
    rnn_model: nn.Module

    # Embedding / recurrent-state dimensionality
    d_model: int = 512
    # Forwarded to `HyperLSTM` as its hyper-network size argument
    n_rhn: int = 16
    # Forwarded to `HyperLSTM` as its `z` embedding size argument
    n_z: int = 16
@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    Initialize the auto-regressive model
    """
    # Build the wrapper around whichever recurrent core `c.rnn_model` resolves to
    model = AutoregressiveModel(c.n_tokens, c.d_model, c.rnn_model)
    # Move it to the configured device before returning
    return model.to(c.device)
@option(Configs.rnn_model)
def hyper_lstm(c: Configs):
    """Create a single-layer `HyperLSTM` recurrent core."""
    return HyperLSTM(c.d_model, c.d_model, c.n_rhn, c.n_z, 1)
|
@option(Configs.rnn_model)
def lstm(c: Configs):
    """Create a single-layer plain `LSTM` recurrent core."""
    return LSTM(c.d_model, c.d_model, 1)
|
def main():
    """Run the HyperLSTM character-level language-modeling experiment."""
    # Create experiment
    experiment.create(name="hyper_lstm", comment='')
    # Create configs
    conf = Configs()
    # Load configurations, overriding the defaults with this experiment's setup
    overrides = {
        'tokenizer': 'character',
        'text': 'tiny_shakespeare',
        'optimizer.learning_rate': 2.5e-4,
        'optimizer.optimizer': 'Adam',
        'prompt': 'It is',
        'prompt_separator': '',

        # Pick the `hyper_lstm` option for the recurrent core
        'rnn_model': 'hyper_lstm',

        'train_loader': 'shuffled_train_loader',
        'valid_loader': 'shuffled_valid_loader',

        'seq_len': 512,
        'epochs': 128,
        'batch_size': 2,
        'inner_iterations': 25,
    }
    experiment.configs(conf, overrides)

    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    # Start the experiment and run the training loop
    with experiment.start():
        # `TrainValidConfigs.run`
        conf.run()
|
# Run the experiment when executed as a script
if __name__ == '__main__':
    main()