Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
Synced 2025-10-31 18:58:43 +08:00

Commit: hyperlstm (author: Varuna Jayasiri)

labml_nn/hypernetworks/__init__.py  0  Normal file

labml_nn/hypernetworks/experiment.py  209  Normal file
@@ -0,0 +1,209 @@
from typing import Callable, Any

import torch
import torch.nn as nn
from labml import lab, experiment, monit, tracker, logger
from labml.configs import option
from labml.logger import Text
from labml.utils.pytorch import get_modules
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex

from labml_nn.hypernetworks.hyper_lstm import HyperLSTM


class AutoregressiveModel(Module):
    """
    ## Auto-regressive model
    """

    def __init__(self, n_vocab: int, d_model: int, n_rhn: int, n_z: int):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Single-layer HyperLSTM
        self.lstm = HyperLSTM(d_model, d_model, n_rhn, n_z, 1)
        # Final linear layer to produce the logits
        self.generator = nn.Linear(d_model, n_vocab)

    def __call__(self, x: torch.Tensor):
        # Embed the tokens (`src`)
        x = self.src_embed(x)
        # Run the embeddings through the HyperLSTM
        res, state = self.lstm(x)
        # Generate logits of the next token
        return self.generator(res), state

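# Shape sketch for the model above (illustrative; a 65-character vocabulary is
# an assumption): token ids go in as `[seq_len, batch_size]`, and the model
# returns `[seq_len, batch_size, n_vocab]` logits together with the final
# `HyperLSTM` state.
#
#     model = AutoregressiveModel(n_vocab=65, d_model=512, n_rhn=16, n_z=16)
#     x = torch.randint(0, 65, (32, 4))   # `[seq_len, batch_size]`
#     logits, state = model(x)            # logits: `[32, 4, 65]`
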
class CrossEntropyLoss(Module):
    """
    Cross entropy loss
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))

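# A sketch of the shapes involved above: `outputs` from the model have shape
# `[seq_len, batch_size, n_vocab]` and `targets` have shape `[seq_len, batch_size]`,
# so the reshape flattens both to per-token rows before applying `nn.CrossEntropyLoss`.
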
class Configs(SimpleTrainValidConfigs):
    """
    ## Configurations

    The default configs can be and will be overridden when we start the experiment.
    """

    model: AutoregressiveModel
    text: TextDataset
    batch_size: int = 20
    seq_len: int = 512
    n_tokens: int
    tokenizer: Callable = 'character'

    is_save_models = True

    optimizer: torch.optim.Adam = 'transformer_optimizer'

    accuracy = Accuracy()
    loss_func = CrossEntropyLoss()

    def init(self):
        # Create a configurable optimizer.
        # Parameters like learning rate can be changed by passing a dictionary when starting the experiment.
        optimizer = OptimizerConfigs()
        optimizer.parameters = self.model.parameters()
        optimizer.optimizer = 'Adam'
        self.optimizer = optimizer

        # Create a sequential data loader for training
        self.train_loader = SequentialDataLoader(text=self.text.train,
                                                 dataset=self.text,
                                                 batch_size=self.batch_size,
                                                 seq_len=self.seq_len)

        # Create a sequential data loader for validation
        self.valid_loader = SequentialDataLoader(text=self.text.valid,
                                                 dataset=self.text,
                                                 batch_size=self.batch_size,
                                                 seq_len=self.seq_len)

        self.state_modules = [self.accuracy]

    def sample(self):
        """
        Sampling function to generate samples periodically while training
        """
        prompt = 'It is'
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output, state = self.model(data)
            output = output.cpu()
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to the prompt
            prompt += self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.text.itos[output[-1]], Text.value)]

        logger.log(log)

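    # A note on the sampling loop above: each step re-runs the whole prompt
    # through the model and discards the returned `state`; for the short
    # 25-character samples generated here that is fine, although carrying the
    # LSTM state across steps would avoid re-encoding the prompt (that would
    # require the model to accept a state argument, which it currently does not).
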
    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        This method is called for each batch
        """
        self.model.train(self.mode.is_train)

        # Get data and target labels
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Run the model
        output, state = self.model(data)

        # Calculate loss
        loss = self.loss_func(output, target)
        # Calculate accuracy
        self.accuracy(output, target)

        # Log the loss
        tracker.add("loss.", loss)

        # If we are in training mode, calculate the gradients and take an optimizer step
        if self.mode.is_train:
            loss.backward()
            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()

def character_tokenizer(x: str):
    return list(x)


@option(Configs.tokenizer)
def character():
    """
    Character level tokenizer
    """
    return character_tokenizer


@option(Configs.text)
def tiny_shakespeare(c: Configs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')


@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    Initialize the auto-regressive model
    """
    m = AutoregressiveModel(c.n_tokens, 512, 16, 16)
    return m.to(c.device)

def main():
    # Create experiment
    experiment.create(name="hyper_lstm", comment='')
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'seq_len': 512,
                        'epochs': 128,
                        'batch_size': 2,
                        'inner_iterations': 10})

    # This is needed to initialize models
    conf.n_tokens = conf.text.n_tokens

    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    conf.init()
    # Start the experiment
    with experiment.start():
        # `TrainValidConfigs.run`
        conf.run()


if __name__ == '__main__':
    main()
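A minimal end-to-end sketch of the pieces above, outside the labml experiment
plumbing (the text, prompt and vocabulary here are illustrative assumptions;
`TextFileDataset` builds its own vocabulary and `text_to_i`/`itos` mappings
from the downloaded data):

    import torch

    from labml_nn.hypernetworks.experiment import AutoregressiveModel

    # Character-level tokenization: the vocabulary is just the set of characters
    text = 'It is a truth universally acknowledged'
    itos = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(itos)}

    # Encode a prompt as `[seq_len, batch_size=1]`, as `Configs.sample` does
    prompt = torch.tensor([stoi[ch] for ch in 'It is']).unsqueeze(-1)

    # Same layer sizes as `autoregressive_model` above, but with the tiny vocabulary
    model = AutoregressiveModel(n_vocab=len(itos), d_model=512, n_rhn=16, n_z=16)
    logits, state = model(prompt)                 # logits: `[5, 1, n_vocab]`
    next_char = itos[logits[-1, 0].argmax()]      # greedy next-character prediction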
							
								
								
									
labml_nn/hypernetworks/hyper_lstm.py  136  Normal file
@@ -0,0 +1,136 @@
from typing import Optional, Tuple

import torch
from labml_helpers.module import Module
from torch import nn

from labml_nn.lstm import LSTMCell


class HyperLSTMCell(Module):
    def __init__(self, input_size: int, hidden_size: int, rhn_hidden_size: int, n_z: int):
        super().__init__()

        self.hidden_size = hidden_size

        # The smaller recurrent hypernetwork: an LSTM cell that sees the input
        # concatenated with the main hidden state
        # TODO: need layernorm
        self.rhn = LSTMCell(hidden_size + input_size, rhn_hidden_size)

        # Project the hypernetwork state to the embeddings $z_h$, $z_x$ and $z_b$
        # (one chunk of size `n_z` per gate, hence `4 * n_z`)
        self.z_h = nn.Linear(rhn_hidden_size, 4 * n_z)
        self.z_x = nn.Linear(rhn_hidden_size, 4 * n_z)
        self.z_b = nn.Linear(rhn_hidden_size, 4 * n_z, bias=False)

        # Project the embeddings to per-gate scaling vectors $d_h$, $d_x$ and biases
        d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_h = nn.ModuleList(d_h)
        d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_x = nn.ModuleList(d_x)
        d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
        self.d_b = nn.ModuleList(d_b)

        # The main LSTM weight matrices, one per gate
        self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])
        self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])

        # Layer normalization for each gate's pre-activation
        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])

    def __call__(self, x: torch.Tensor,
                 h: torch.Tensor, c: torch.Tensor,
                 rhn_h: torch.Tensor, rhn_c: torch.Tensor):
        # Step the hypernetwork on the concatenation of the hidden state and the input
        rhn_x = torch.cat((h, x), dim=-1)
        rhn_h, rhn_c = self.rhn(rhn_x, rhn_h, rhn_c)

        # Compute the embeddings and split them into one chunk per gate
        z_h = self.z_h(rhn_h).chunk(4, dim=-1)
        z_x = self.z_x(rhn_h).chunk(4, dim=-1)
        z_b = self.z_b(rhn_h).chunk(4, dim=-1)

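        # A sketch of what the loop below computes for each gate: the vectors
        # $d_h(z_h)$ and $d_x(z_x)$ scale the rows of the fixed weight matrices
        # $W_h$ and $W_x$, and the bias is generated from $z_b$, giving the
        # pre-activation
        # $$\mathrm{LN}\big(d_h(z_h) \odot (W_h h_{t-1}) + d_x(z_x) \odot (W_x x_t) + b(z_b)\big)$$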
        ifgo = []
        for i in range(4):
            # Scale the rows of `self.w_h[i]` by `d_h` (per batch element)
            d_h = self.d_h[i](z_h[i])
            w_h = torch.einsum('ij,bi->bij', self.w_h[i], d_h)
            # Scale the rows of `self.w_x[i]` by `d_x` (per batch element)
            d_x = self.d_x[i](z_x[i])
            w_x = torch.einsum('ij,bi->bij', self.w_x[i], d_x)
            # Bias generated from $z_b$
            b = self.d_b[i](z_b[i])

            # Pre-activation of the gate
            g = torch.einsum('bij,bj->bi', w_h, h) + \
                torch.einsum('bij,bj->bi', w_x, x) + \
                b

            ifgo.append(self.layer_norm[i](g))

        # `ifgo` now holds the four layer-normalized pre-activations
        # $\hat{i}_t, \hat{f}_t, \hat{g}_t, \hat{o}_t$

        # $$i_t = \sigma(\hat{i}_t)$$
        i = torch.sigmoid(ifgo[0])
        # $$f_t = \sigma(\hat{f}_t)$$
        f = torch.sigmoid(ifgo[1])
        # $$g_t = \tanh(\hat{g}_t)$$
        g = torch.tanh(ifgo[2])
        # $$o_t = \sigma(\hat{o}_t)$$
        o = torch.sigmoid(ifgo[3])

        # $$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$
        c_next = f * c + i * g

        # $$h_t = o_t \odot \tanh(c_t)$$
        h_next = o * torch.tanh(c_next)

        return h_next, c_next, rhn_h, rhn_c

class HyperLSTM(Module):
    def __init__(self, input_size: int, hidden_size: int, rhn_hidden_size: int, n_z: int, n_layers: int):
        """
        Create a network of `n_layers` HyperLSTM layers.
        """

        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.rhn_hidden_size = rhn_hidden_size
        # Create cells for each layer. Note that only the first layer gets the input directly.
        # The rest of the layers get their input from the layer below.
        self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, rhn_hidden_size, n_z)] +
                                   [HyperLSTMCell(hidden_size, hidden_size, rhn_hidden_size, n_z) for _ in
                                    range(n_layers - 1)])

    def __call__(self, x: torch.Tensor,
                 state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
        """
        `x` has shape `[seq_len, batch_size, input_size]` and
        `state` is a tuple of `(h, c, rhn_h, rhn_c)`; `h` and `c` have shape
        `[n_layers, batch_size, hidden_size]` and the hypernetwork states
        `rhn_h` and `rhn_c` have shape `[n_layers, batch_size, rhn_hidden_size]`.
        """
        time_steps, batch_size = x.shape[:2]

        # Initialize the state if `None`
        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            rhn_h = [x.new_zeros(batch_size, self.rhn_hidden_size) for _ in range(self.n_layers)]
            rhn_c = [x.new_zeros(batch_size, self.rhn_hidden_size) for _ in range(self.n_layers)]
        else:
            (h, c, rhn_h, rhn_c) = state
            # Unstack the tensors to get the states of each layer <br />
            # 📝 You can just work with the tensor itself but this is easier to debug
            h, c = list(torch.unbind(h)), list(torch.unbind(c))
            rhn_h, rhn_c = list(torch.unbind(rhn_h)), list(torch.unbind(rhn_c))

        # Array to collect the outputs of the final layer at each time step.
        out = []
        for t in range(time_steps):
            # Input to the first layer is the input itself
            inp = x[t]
            # Loop through the layers
            for layer in range(self.n_layers):
                # Step the cell of this layer and update its state
                h[layer], c[layer], rhn_h[layer], rhn_c[layer] = \
                    self.cells[layer](inp, h[layer], c[layer], rhn_h[layer], rhn_c[layer])
                # Input to the next layer is the output of this layer
                inp = h[layer]
            # Collect the output $h$ of the final layer
            out.append(h[-1])

        # Stack the outputs and states
        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)
        rhn_h = torch.stack(rhn_h)
        rhn_c = torch.stack(rhn_c)

        return out, (h, c, rhn_h, rhn_c)
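A quick standalone shape check for HyperLSTM (a sketch; the layer sizes follow
the `autoregressive_model` configuration above, while the sequence length and
batch size are arbitrary illustrations):

    import torch

    from labml_nn.hypernetworks.hyper_lstm import HyperLSTM

    lstm = HyperLSTM(input_size=512, hidden_size=512, rhn_hidden_size=16, n_z=16, n_layers=1)
    x = torch.randn(32, 4, 512)        # `[seq_len, batch_size, input_size]`
    out, (h, c, rhn_h, rhn_c) = lstm(x)

    assert out.shape == (32, 4, 512)   # output of the last layer at every step
    assert h.shape == (1, 4, 512)      # final hidden state, one entry per layer
    assert rhn_h.shape == (1, 4, 16)   # final hypernetwork hidden state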