Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-10-31 18:58:43 +08:00
	hyperlstm
New file: labml_nn/hypernetworks/__init__.py  (0 lines)
								
								
									
New file: labml_nn/hypernetworks/experiment.py  (209 lines)
| @@ -0,0 +1,209 @@ | ||||
| from typing import Callable, Any | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from labml import lab, experiment, monit, tracker, logger | ||||
| from labml.configs import option | ||||
| from labml.logger import Text | ||||
| from labml.utils.pytorch import get_modules | ||||
| from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset | ||||
| from labml_helpers.metrics.accuracy import Accuracy | ||||
| from labml_helpers.module import Module | ||||
| from labml_helpers.optimizer import OptimizerConfigs | ||||
| from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex | ||||
|  | ||||
| from labml_nn.hypernetworks.hyper_lstm import HyperLSTM | ||||
|  | ||||
|  | ||||
| class AutoregressiveModel(Module): | ||||
|     """ | ||||
|     ## Auto-regressive model | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, n_vocab: int, d_model: int, n_rhn: int, n_z: int): | ||||
|         super().__init__() | ||||
|         # Token embedding module | ||||
|         self.src_embed = nn.Embedding(n_vocab, d_model) | ||||
|         self.lstm = HyperLSTM(d_model, d_model, n_rhn, n_z, 1) | ||||
|         self.generator = nn.Linear(d_model, n_vocab) | ||||
|  | ||||
|     def __call__(self, x: torch.Tensor): | ||||
|         x = self.src_embed(x) | ||||
|         # Run the embeddings through the HyperLSTM | ||||
|         res, state = self.lstm(x) | ||||
|         # Generate logits of the next token | ||||
|         return self.generator(res), state | ||||
|  | ||||
|  | ||||
| class CrossEntropyLoss(Module): | ||||
|     """ | ||||
|     Cross entropy loss | ||||
|     """ | ||||
|  | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.loss = nn.CrossEntropyLoss() | ||||
|  | ||||
|     def __call__(self, outputs, targets): | ||||
|         return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1)) | ||||
|  | ||||
|  | ||||
| class Configs(SimpleTrainValidConfigs): | ||||
|     """ | ||||
|     ## Configurations | ||||
|  | ||||
|     The default configs can and will be overridden when we start the experiment | ||||
|     """ | ||||
|  | ||||
|     model: AutoregressiveModel | ||||
|     text: TextDataset | ||||
|     batch_size: int = 20 | ||||
|     seq_len: int = 512 | ||||
|     n_tokens: int | ||||
|     tokenizer: Callable = 'character' | ||||
|  | ||||
|     is_save_models = True | ||||
|  | ||||
|     optimizer: torch.optim.Adam = 'transformer_optimizer' | ||||
|  | ||||
|     accuracy = Accuracy() | ||||
|     loss_func = CrossEntropyLoss() | ||||
|  | ||||
|     def init(self): | ||||
|         # Create a configurable optimizer. | ||||
|         # Parameters like learning rate can be changed by passing a dictionary when starting the experiment. | ||||
|         optimizer = OptimizerConfigs() | ||||
|         optimizer.parameters = self.model.parameters() | ||||
|         optimizer.optimizer = 'Adam' | ||||
|         self.optimizer = optimizer | ||||
|  | ||||
|         # Create a sequential data loader for training | ||||
|         self.train_loader = SequentialDataLoader(text=self.text.train, | ||||
|                                                  dataset=self.text, | ||||
|                                                  batch_size=self.batch_size, | ||||
|                                                  seq_len=self.seq_len) | ||||
|  | ||||
|         # Create a sequential data loader for validation | ||||
|         self.valid_loader = SequentialDataLoader(text=self.text.valid, | ||||
|                                                  dataset=self.text, | ||||
|                                                  batch_size=self.batch_size, | ||||
|                                                  seq_len=self.seq_len) | ||||
|  | ||||
|         self.state_modules = [self.accuracy] | ||||
|  | ||||
|     def sample(self): | ||||
|         """ | ||||
|         Sampling function to generate samples periodically while training | ||||
|         """ | ||||
|         prompt = 'It is' | ||||
|         log = [(prompt, Text.subtle)] | ||||
|         # Sample 25 tokens | ||||
|         for i in monit.iterate('Sample', 25): | ||||
|             # Tokenize the prompt | ||||
|             data = self.text.text_to_i(prompt).unsqueeze(-1) | ||||
|             data = data.to(self.device) | ||||
|             # Get the model output | ||||
|             output, state = self.model(data) | ||||
|             output = output.cpu() | ||||
|             # Get the model prediction (greedy) | ||||
|             output = output.argmax(dim=-1).squeeze() | ||||
|             # Add the prediction to prompt | ||||
|             prompt += self.text.itos[output[-1]] | ||||
|             # Add the prediction for logging | ||||
|             log += [(self.text.itos[output[-1]], Text.value)] | ||||
|  | ||||
|         logger.log(log) | ||||
|  | ||||
|     def step(self, batch: Any, batch_idx: BatchIndex): | ||||
|         """ | ||||
|         This method is called for each batch | ||||
|         """ | ||||
|         self.model.train(self.mode.is_train) | ||||
|  | ||||
|         # Get data and target labels | ||||
|         data, target = batch[0].to(self.device), batch[1].to(self.device) | ||||
|  | ||||
|         if self.mode.is_train: | ||||
|             tracker.add_global_step(data.shape[0] * data.shape[1]) | ||||
|  | ||||
|         # Run the model | ||||
|         output, state = self.model(data) | ||||
|  | ||||
|         # Calculate loss | ||||
|         loss = self.loss_func(output, target) | ||||
|         # Calculate accuracy | ||||
|         self.accuracy(output, target) | ||||
|  | ||||
|         # Log the loss | ||||
|         tracker.add("loss.", loss) | ||||
|  | ||||
|         #  If we are in training mode, calculate the gradients | ||||
|         if self.mode.is_train: | ||||
|             loss.backward() | ||||
|             self.optimizer.step() | ||||
|             if batch_idx.is_last: | ||||
|                 tracker.add('model', self.model) | ||||
|             self.optimizer.zero_grad() | ||||
|  | ||||
|         tracker.save() | ||||
|  | ||||
|  | ||||
| def character_tokenizer(x: str): | ||||
|     return list(x) | ||||
|  | ||||
|  | ||||
| @option(Configs.tokenizer) | ||||
| def character(): | ||||
|     """ | ||||
|     Character level tokenizer | ||||
|     """ | ||||
|     return character_tokenizer | ||||
|  | ||||
|  | ||||
| @option(Configs.text) | ||||
| def tiny_shakespeare(c: Configs): | ||||
|     return TextFileDataset( | ||||
|         lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer, | ||||
|         url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt') | ||||
|  | ||||
|  | ||||
| @option(Configs.model) | ||||
| def autoregressive_model(c: Configs): | ||||
|     """ | ||||
|     Initialize the auto-regressive model | ||||
|     """ | ||||
|     m = AutoregressiveModel(c.n_tokens, 512, 16, 16) | ||||
|     return m.to(c.device) | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     # Create experiment | ||||
|     experiment.create(name="hyper_lstm", comment='') | ||||
|     # Create configs | ||||
|     conf = Configs() | ||||
|     # Load configurations | ||||
|     experiment.configs(conf, | ||||
|                        # A dictionary of configurations to override | ||||
|                        {'tokenizer': 'character', | ||||
|                         'text': 'tiny_shakespeare', | ||||
|  | ||||
|                         'seq_len': 512, | ||||
|                         'epochs': 128, | ||||
|                         'batch_size': 2, | ||||
|                         'inner_iterations': 10}) | ||||
|  | ||||
|     # This is needed to initialize models | ||||
|     conf.n_tokens = conf.text.n_tokens | ||||
|  | ||||
|     # Set models for saving and loading | ||||
|     experiment.add_pytorch_models(get_modules(conf)) | ||||
|  | ||||
|     conf.init() | ||||
|     # Start the experiment | ||||
|     with experiment.start(): | ||||
|         # `TrainValidConfigs.run` | ||||
|         conf.run() | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||
							
								
								
									
New file: labml_nn/hypernetworks/hyper_lstm.py  (136 lines)
| @@ -0,0 +1,136 @@ | ||||
| from typing import Optional, Tuple | ||||
|  | ||||
| import torch | ||||
| from labml_helpers.module import Module | ||||
| from torch import nn | ||||
|  | ||||
| from labml_nn.lstm import LSTMCell | ||||
|  | ||||
|  | ||||
| class HyperLSTMCell(Module): | ||||
|     def __init__(self, input_size: int, hidden_size: int, rhn_hidden_size: int, n_z: int): | ||||
|         super().__init__() | ||||
|  | ||||
|         self.hidden_size = hidden_size | ||||
|  | ||||
|         # TODO: need layernorm | ||||
|         self.rhn = LSTMCell(hidden_size + input_size, rhn_hidden_size) | ||||
|  | ||||
|         self.z_h = nn.Linear(rhn_hidden_size, 4 * n_z) | ||||
|         self.z_x = nn.Linear(rhn_hidden_size, 4 * n_z) | ||||
|         self.z_b = nn.Linear(rhn_hidden_size, 4 * n_z, bias=False) | ||||
|  | ||||
|         d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)] | ||||
|         self.d_h = nn.ModuleList(d_h) | ||||
|         d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)] | ||||
|         self.d_x = nn.ModuleList(d_x) | ||||
|         d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)] | ||||
|         self.d_b = nn.ModuleList(d_b) | ||||
|  | ||||
|         self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)]) | ||||
|         self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)]) | ||||
|  | ||||
|         self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)]) | ||||
|  | ||||
|     def __call__(self, x: torch.Tensor, | ||||
|                  h: torch.Tensor, c: torch.Tensor, | ||||
|                  rhn_h: torch.Tensor, rhn_c: torch.Tensor): | ||||
|         # The hyper (RHN) LSTM cell sees the main cell's previous hidden state and the current input | ||||
|         rhn_x = torch.cat((h, x), dim=-1) | ||||
|         rhn_h, rhn_c = self.rhn(rhn_x, rhn_h, rhn_c) | ||||
|  | ||||
|         # Embeddings $z$ produced by the hyper network, one chunk per gate ($i$, $f$, $g$, $o$) | ||||
|         z_h = self.z_h(rhn_h).chunk(4, dim=-1) | ||||
|         z_x = self.z_x(rhn_h).chunk(4, dim=-1) | ||||
|         z_b = self.z_b(rhn_h).chunk(4, dim=-1) | ||||
|  | ||||
|         # Compute the pre-activations of the four gates | ||||
|         ifgo = [] | ||||
|         for i in range(4): | ||||
|             # Scaling vectors $d_h$ and $d_x$ generated from the $z$ embeddings | ||||
|             d_h = self.d_h[i](z_h[i]) | ||||
|             # Scale the rows of $W_h$, giving one weight matrix per batch element | ||||
|             w_h = torch.einsum('ij,bi->bij', self.w_h[i], d_h) | ||||
|             d_x = self.d_x[i](z_x[i]) | ||||
|             w_x = torch.einsum('ij,bi->bij', self.w_x[i], d_x) | ||||
|             # Dynamically generated bias | ||||
|             b = self.d_b[i](z_b[i]) | ||||
|  | ||||
|             # Gate pre-activation $d_h \odot (W_h h_{t-1}) + d_x \odot (W_x x_t) + b$ | ||||
|             g = torch.einsum('bij,bj->bi', w_h, h) + \ | ||||
|                 torch.einsum('bij,bj->bi', w_x, x) + \ | ||||
|                 b | ||||
|  | ||||
|             # Layer-normalize the pre-activation | ||||
|             ifgo.append(self.layer_norm[i](g)) | ||||
|  | ||||
|         # Apply the gate non-linearities to the layer-normed, hypernetwork-modulated pre-activations | ||||
|         # $$i_t = \sigma\big(\hat{i}_t\big)$$ | ||||
|         i = torch.sigmoid(ifgo[0]) | ||||
|         # $$f_t = \sigma\big(\hat{f}_t\big)$$ | ||||
|         f = torch.sigmoid(ifgo[1]) | ||||
|         # $$g_t = \tanh\big(\hat{g}_t\big)$$ | ||||
|         g = torch.tanh(ifgo[2]) | ||||
|         # $$o_t = \sigma\big(\hat{o}_t\big)$$ | ||||
|         o = torch.sigmoid(ifgo[3]) | ||||
|  | ||||
|         # $$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$ | ||||
|         c_next = f * c + i * g | ||||
|  | ||||
|         # $$h_t = o_t \odot \tanh(c_t)$$ | ||||
|         h_next = o * torch.tanh(c_next) | ||||
|  | ||||
|         return h_next, c_next, rhn_h, rhn_c | ||||
|  | ||||
|  | ||||
| class HyperLSTM(Module): | ||||
|     def __init__(self, input_size: int, hidden_size: int, rhn_hidden_size: int, n_z: int, n_layers: int): | ||||
|         """ | ||||
|         Create a network of `n_layers` stacked HyperLSTM layers. | ||||
|         """ | ||||
|  | ||||
|         super().__init__() | ||||
|         self.n_layers = n_layers | ||||
|         self.hidden_size = hidden_size | ||||
|         self.rhn_hidden_size = rhn_hidden_size | ||||
|         # Create cells for each layer. Note that only the first layer gets the input directly. | ||||
|         # The rest of the layers get their input from the layer below | ||||
|         self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, rhn_hidden_size, n_z)] + | ||||
|                                    [HyperLSTMCell(hidden_size, hidden_size, rhn_hidden_size, n_z) for _ in | ||||
|                                     range(n_layers - 1)]) | ||||
|  | ||||
|     def __call__(self, x: torch.Tensor, | ||||
|                  state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None): | ||||
|         """ | ||||
|         `x` has shape `[seq_len, batch_size, input_size]`. | ||||
|         `state` is a tuple of $(h, c, \hat{h}, \hat{c})$, where $h$ and $c$ have shape | ||||
|         `[n_layers, batch_size, hidden_size]` and the hyper (RHN) LSTM states $\hat{h}$ and $\hat{c}$ | ||||
|         have shape `[n_layers, batch_size, rhn_hidden_size]`. | ||||
|         """ | ||||
|         time_steps, batch_size = x.shape[:2] | ||||
|  | ||||
|         # Initialize the state if `None` | ||||
|         if state is None: | ||||
|             h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)] | ||||
|             c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)] | ||||
|             rhn_h = [x.new_zeros(batch_size, self.rhn_hidden_size) for _ in range(self.n_layers)] | ||||
|             rhn_c = [x.new_zeros(batch_size, self.rhn_hidden_size) for _ in range(self.n_layers)] | ||||
|         else: | ||||
|             (h, c, rhn_h, rhn_c) = state | ||||
|             # Reverse stack the tensors to get the states of each layer <br /> | ||||
|             # 📝 You can just work with the tensor itself but this is easier to debug | ||||
|             h, c = list(torch.unbind(h)), list(torch.unbind(c)) | ||||
|             rhn_h, rhn_c = list(torch.unbind(rhn_h)), list(torch.unbind(rhn_c)) | ||||
|  | ||||
|         # Array to collect the outputs of the final layer at each time step. | ||||
|         out = [] | ||||
|         for t in range(time_steps): | ||||
|             # Input to the first layer is the input itself | ||||
|             inp = x[t] | ||||
|             # Loop through the layers | ||||
|             for layer in range(self.n_layers): | ||||
|                 # Compute the new state of this layer | ||||
|                 h[layer], c[layer], rhn_h[layer], rhn_c[layer] = \ | ||||
|                     self.cells[layer](inp, h[layer], c[layer], rhn_h[layer], rhn_c[layer]) | ||||
|                 # Input to the next layer is the state of this layer | ||||
|                 inp = h[layer] | ||||
|             # Collect the output $h$ of the final layer | ||||
|             out.append(h[-1]) | ||||
|  | ||||
|         # Stack the outputs and states | ||||
|         out = torch.stack(out) | ||||
|         h = torch.stack(h) | ||||
|         c = torch.stack(c) | ||||
|         rhn_h = torch.stack(rhn_h) | ||||
|         rhn_c = torch.stack(rhn_c) | ||||
|  | ||||
|         return out, (h, c, rhn_h, rhn_c) | ||||
Author: Varuna Jayasiri