Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
	🚧 compressive transformer
labml_nn/transformers/compressive/__init__.py (new file, 183 lines)

@@ -0,0 +1,183 @@
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn

from labml_helpers.module import Module, TypedModuleList
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
from labml_nn.utils import clone_module_list


class Conv1dCompression(Module):
    """
    ## 1D convolution memory compression

    Compresses every `compression_ratio` memory vectors into a single compressed
    memory, using a 1D convolution whose kernel size and stride equal `compression_ratio`.
    """

    def __init__(self, compression_ratio: int, d_model: int):
        """
        * `compression_ratio` is the number of memories compressed into one
        * `d_model` is the embedding size
        """
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_ratio, stride=compression_ratio)

    def forward(self, mem: torch.Tensor):
        """
        * `mem` has shape `[seq_len, batch, d_model]`
        """

        # Change the dimensions of `mem` so that we can run it through the convolution layer.
        # The convolution layer accepts input in the form `[batch, features, sequence]`
        mem = mem.permute(1, 2, 0)
        # Get the compressed memory by running it through the convolution layer
        c_mem = self.conv(mem)
        # Permute back to the form `[seq_len, batch, d_model]`
        return c_mem.permute(2, 0, 1)

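# A minimal usage sketch (illustrative, not part of the committed file): with
# `compression_ratio=4`, sixteen memory vectors per sequence are compressed
# down to four.
#
#     compress = Conv1dCompression(compression_ratio=4, d_model=128)
#     mem = torch.zeros(16, 2, 128)   # [mem_len, batch, d_model]
#     c_mem = compress(mem)           # -> torch.Size([4, 2, 128])
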

class CompressiveTransformerLayer(Module):
    """
    ## Compressive Transformer Layer

    This is a single compressive transformer layer.
    """

    def __init__(self, *,
                 d_model: int,
                 self_attn: RelativeMultiHeadAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float,
                 compress: Conv1dCompression):
        """
        * `d_model` is the token embedding size
        * `self_attn` is the [self attention module](relative_mha.html)
        * `feed_forward` is the feed forward module
        * `dropout_prob` is the probability of dropping out after self attention and FFN
        * `compress` is the memory compression module
        """
        super().__init__()
        self.compress = compress
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def with_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):
        """
        Concatenate the normalized memories (and compressed memories) before `z`,
        along the sequence dimension.
        """
        # If there is no memory, just return `z`
        if mem is None:
            return z

        # Concatenate the compressed memory and the memory
        if c_mem is not None:
            mem = torch.cat((c_mem, mem), dim=0)

        # Normalize the memories and prepend them to `z`
        mem = self.norm_self_attn(mem)
        return torch.cat((mem, z), dim=0)

    def forward(self, *,
                x: torch.Tensor,
                mem: Optional[torch.Tensor],
                c_mem: Optional[torch.Tensor],
                mask: torch.Tensor):
        """
        * `x` are the token level feature vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` are the past token level feature vectors of shape `[mem_len, batch_size, d_model]`
        * `c_mem` are the compressed memory vectors of shape `[c_mem_len, batch_size, d_model]`
        * `mask` is a matrix of shape `[seq_len, c_mem_len + mem_len + seq_len, batch_size]`
        or `[seq_len, c_mem_len + mem_len + seq_len, 1]`.
        `mask[i, j]` is true if the token at `i` can attend to the token at `j`.
        """

        # Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
        # Concatenate the normalized memories with `z` to form the keys and values
        m_z = self.with_memory(z, mem, c_mem)
        # Attention
        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
        # Add the attention results
        x = x + self.dropout(self_attn)

        # Normalize for feed-forward
        z = self.norm_ff(x)
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        #
        return x

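# A minimal sketch of one layer step (shapes only; assumes `layer` is a
# `CompressiveTransformerLayer` built with `d_model=128`): with a chunk of
# 8 tokens, 8 uncompressed memories and 4 compressed memories, the keys and
# values span 4 + 8 + 8 = 20 positions, so the mask must cover those.
#
#     x = torch.zeros(8, 2, 128)                          # [seq_len, batch, d_model]
#     mem, c_mem = torch.zeros(8, 2, 128), torch.zeros(4, 2, 128)
#     mask = torch.ones(8, 20, 1, dtype=torch.bool)       # full visibility, for illustration
#     out = layer(x=x, mem=mem, c_mem=c_mem, mask=mask)   # -> [8, 2, 128]
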

class CompressiveTransformer(Module):
    """
    ## Compressive Transformer Model

    This consists of multiple compressive transformer layers.
    """

    def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):
        """
        * `x` are the token embedding vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` is a list of past token level feature vectors of shape `[mem_len, batch_size, d_model]`, one per layer
        * `c_mem` is a list of compressed memories of shape `[c_mem_len, batch_size, d_model]`, one per layer
        * `mask` is the masking matrix
        """
        # List to store token level feature vectors,
        # which will become the memories for the next sequential batch.
        new_mem = []
        # Run through each transformer layer
        for i, layer in enumerate(self.layers):
            # Add to the list of feature vectors
            new_mem.append(x.detach())
            # Memory for this layer
            m = mem[i] if mem else None
            # Compressed memory for this layer
            cm = c_mem[i] if c_mem else None
            # Run through the transformer layer
            x = layer(x=x, mem=m, c_mem=cm, mask=mask)
        # Finally, normalize the vectors
        return self.norm(x), new_mem

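# A minimal end-to-end sketch (illustrative; assumes `model` is a
# `CompressiveTransformer` built from a layer with `d_model=128` and
# `n_layers=6`): without memories, a causal mask over the chunk is enough.
#
#     from labml_nn.transformers.utils import subsequent_mask
#     x = torch.zeros(8, 2, 128)                    # [seq_len, batch, d_model]
#     out, new_mem = model(x, mem=[], c_mem=[], mask=subsequent_mask(8))
#     # out is [8, 2, 128]; new_mem is a list of 6 detached [8, 2, 128] tensors
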

class AttentionReconstructionLoss:
    """
    ## Attention Reconstruction Loss

    This is the MSE difference between attending over the uncompressed memories
    and attending over the compressed memories. Queries, keys and values are
    computed with the detached parameters of each layer's attention module, so
    this loss only trains the compression functions.
    """

    def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
        """
        * `layers` is the list of compressive transformer layers
        """
        self.layers = layers
        self.loss_func = nn.MSELoss()

    def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):
        """
        Apply the query/key/value projection of `pmha` with detached parameters.
        """
        head_shape = x.shape[:-1]

        # Linear transform with detached weights
        weight = pmha.linear.weight.detach()
        bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None
        x = F.linear(x, weight, bias)

        # Split last dimension into heads
        x = x.view(*head_shape, pmha.heads, pmha.d_k)

        # Output has shape `[seq_len, batch_size, heads, d_k]`
        return x

    def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        """
        Compute multi-head attention with detached parameters (content-based scores only).
        """
        query = self.prepare_for_attn(layer.query, query)
        key = self.prepare_for_attn(layer.key, key)
        value = self.prepare_for_attn(layer.value, value)

        # Compute attention scores $Q K^\top$.
        # This gives a tensor of shape `[seq_len_q, seq_len_k, batch_size, heads]`.
        scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= layer.scale

        # $softmax$ attention along the key sequence dimension
        # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = layer.softmax(scores)

        # Multiply by values
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
        return torch.einsum("ijbh,jbhd->ibhd", attn, value)

    def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
        """
        Calculate the reconstruction loss for a single layer.
        """
        # Detach the hidden states and memories, since this loss should only
        # train the compression function
        h = h.detach()
        mem = mem.detach()

        # Compress the memories
        c_mem = layer.compress(mem)

        # MSE between attention over the original memories and attention over the compressed memories
        return self.loss_func(self.attn(layer.self_attn, h, mem, mem),
                              self.attn(layer.self_attn, h, c_mem, c_mem))

    def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):
        # Sum the reconstruction losses of all layers
        losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]
        return sum(losses)
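
# A minimal sketch of how this loss is used (mirrors the experiment code that
# follows): `new_mem` holds the detached inputs to each layer from the latest
# forward pass, and `old_mem` holds the memories that are about to be compressed.
# Here `model` is assumed to be an `AutoregressiveModel` as in the experiment below.
#
#     ar_loss_fn = AttentionReconstructionLoss(model.transformer.layers)
#     ar_loss = ar_loss_fn(new_mem, old_mem)   # scalar: sum of per-layer MSE losses
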
labml_nn/transformers/compressive/experiment.py (new file, 327 lines)

@@ -0,0 +1,327 @@
"""
---
title: Compressive Transformer Experiment
summary: This experiment trains a compressive transformer model on the Tiny Shakespeare dataset.
---

# Compressive Transformer Experiment

This is an annotated PyTorch experiment to train a compressive transformer model.
"""
from typing import List, Tuple, NamedTuple

import torch
import torch.nn as nn

from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_helpers.metrics.simple_state import SimpleStateModule
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex, hook_model_outputs
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
    CompressiveTransformerLayer, Conv1dCompression


class CompressedMemory(NamedTuple):
    """
    Memories and compressed memories, one tensor per layer.
    """
    mem: List[torch.Tensor]
    c_mem: List[torch.Tensor]


class AutoregressiveModel(Module):
    """
    ## Auto-regressive model
    """

    def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Transformer
        self.transformer = transformer
        # Final layer
        self.generator = nn.Linear(d_model, n_vocab)
        # Masks
        self.mask_x = None
        self.mask_mem = None

    def forward(self, x: torch.Tensor, mem: CompressedMemory):
        # Unpack the memories and compressed memories
        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem = []
            c_mem = []

        # Total length of the memories and compressed memories
        # (the key/value positions beyond the current chunk)
        m_len = len(mem[0]) if mem else 0
        if c_mem:
            m_len += len(c_mem[0])

        # Create a subsequent (causal) mask for the tokens
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
        # Create an all-ones (full visibility) mask for the memories
        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

        # Concatenate the masks if there is memory
        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
        # Use the subsequent mask otherwise
        else:
            mask = self.mask_x[:len(x), :len(x)]

        # Token embeddings
        x = self.src_embed(x)
        # Run it through the transformer
        res, mem = self.transformer(x, mem, c_mem, mask)
        # Generate logits of the next token
        res = self.generator(res)
        #
        return res, mem

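# A worked example of the mask construction above (illustrative numbers): with
# `seq_len=8`, 8 uncompressed and 4 compressed memories, `m_len = 12`, so
# `self.mask_mem[:8, :12]` is an all-ones `[8, 12, 1]` tensor,
# `self.mask_x[:8, :8]` is the `[8, 8, 1]` subsequent (causal) mask, and the
# concatenated `mask` has shape `[8, 20, 1]`: each token attends to all
# memories plus the tokens up to and including itself.
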

class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    The default configurations can and will be overridden when we start the experiment.
    """

    model: AutoregressiveModel

    # Token embedding size
    d_model: int = 128
    # Number of attention heads
    heads: int = 4
    # Dropout probability
    dropout: float = 0.0
    # Number of features in FFN hidden layer
    d_ff: int = 256
    # Number of transformer layers
    n_layers: int = 6
    # Number of memories to keep
    mem_len: int = 8
    # State module to maintain memories when switching between training and validation
    memory = SimpleStateModule()
    # Attention Reconstruction Loss
    attention_reconstruction_loss: AttentionReconstructionLoss
    # Compression ratio
    compression_ratio: int = 4
    # Maximum number of compressed memories to keep
    c_mem_len: int = 128

    def init(self):
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("ar_loss.*", False)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # This will keep the accuracy metric stats and memories separate for training and validation.
        self.state_modules = [self.accuracy, self.memory]

    @torch.no_grad()
    def merge_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
            -> Tuple[CompressedMemory, List[torch.Tensor]]:
        """
        Concatenate new memories with the old, compress the oldest memories that
        overflow `mem_len`, and keep at most `c_mem_len` compressed memories.
        Returns the updated `CompressedMemory` along with the memories that were
        compressed in this step (used for the attention reconstruction loss).
        """

        # If it's configured not to use memory
        if self.mem_len == 0:
            return CompressedMemory([], []), []

        # Unpack memories and compressed memories
        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem, c_mem = [], []
        # Concatenate with old memory
        if mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
        else:
            mem = new_mem

        # If the memory is longer than `mem_len`, compress the oldest memories
        if len(mem[0]) > self.mem_len:
            # Number of compressed memories to create
            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_ratio - 1) // self.compression_ratio
            # Split off the oldest memories for compression and keep the rest
            old_mem = []
            trunc_mem = []
            for m in mem:
                n_old = n_c_mem * self.compression_ratio
                cm, m = torch.split(m, [n_old, len(m) - n_old])
                old_mem.append(cm)
                trunc_mem.append(m)
            mem = trunc_mem

            # Compress the oldest memories with each layer's compression module
            new_c_mem = []
            for i, layer in enumerate(self.model.transformer.layers):
                new_c_mem.append(layer.compress(old_mem[i]))

            # Concatenate with the existing compressed memories
            if c_mem:
                c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]
            else:
                c_mem = new_c_mem

            # Truncate compressed memories to a maximum of `c_mem_len`
            if len(c_mem[0]) > self.c_mem_len:
                c_mem = [m[-self.c_mem_len:] for m in c_mem]
        # Nothing was compressed in this step
        else:
            old_mem = []

        #
        return CompressedMemory(mem, c_mem), old_mem

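    # A worked example with the defaults used below (`mem_len=8`,
    # `compression_ratio=4`, chunks of 8 tokens): after concatenating the new
    # memories the memory holds 16 vectors, so
    # `n_c_mem = (16 - 8 + 4 - 1) // 4 = 2`; the oldest `2 * 4 = 8` vectors are
    # split off and compressed into 2 vectors that are appended to `c_mem`
    # (capped at `c_mem_len`), while the newest 8 are kept as uncompressed memory.
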
    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training/validation step
        """

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get memories
            mem = self.memory.get()
            # Run the model
            output, new_mem = self.model(data, mem)
            # Merge memory
            mem, old_mem = self.merge_memory(mem, new_mem)
            # Update memories
            self.memory.set(mem)

        # Calculate and log cross entropy loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log the attention reconstruction loss if memories were compressed in this step
        if old_mem:
            ar_loss = self.attention_reconstruction_loss(new_mem, old_mem)
            tracker.add("ar_loss.", ar_loss)
            # loss = loss + ar_loss

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    def sample(self):
        """
        ### Sampling function to generate samples periodically while training
        """

        # Starting prompt
        prompt = self.prompt
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Start with an empty memory
        mem = CompressedMemory([], [])
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            # Move to device
            data = data.to(self.device)
            # Get the model output
            output, new_mem = self.model(data, mem)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze(1)
            # Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Only feed the last character to the model in the next iteration; the rest go in as memories
            prompt = prompt[-1:]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
            # Update memory
            mem, _ = self.merge_memory(mem, new_mem)

        # Print the sampled output
        logger.log(log)

@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    ### Initialize the auto-regressive model
    """
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer(
        CompressiveTransformerLayer(d_model=c.d_model,
                                    self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                                    feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                    dropout_prob=c.dropout,
                                    compress=Conv1dCompression(c.compression_ratio, c.d_model)), c.n_layers))
    return m.to(c.device)


@option(Configs.attention_reconstruction_loss)
def attention_reconstruction_loss(c: Configs):
    """
    ### Initialize the attention reconstruction loss
    """
    return AttentionReconstructionLoss(c.model.transformer.layers)


def main():
    """
    ### Run the experiment
    """
    # Create experiment
    experiment.create(name="compressive_transformer", comment='')
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 2.5e-4,
                        'optimizer.optimizer': 'AdamW',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 8,
                        'mem_len': 8,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # `TrainValidConfigs.run`
        conf.run()


#
if __name__ == '__main__':
    main()
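
# To run this experiment one would typically execute this module directly,
# e.g. `python -m labml_nn.transformers.compressive.experiment` (module path
# assumed from the file location), which calls `main()` above.
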
labml_nn/utils.py

@@ -9,13 +9,11 @@ summary: A bunch of utility functions and classes
 
 import copy
 
-from torch import nn
-
-from labml_helpers.module import Module
+from labml_helpers.module import M, TypedModuleList
 
 
-def clone_module_list(module: Module, n: int):
+def clone_module_list(module: M, n: int) -> TypedModuleList[M]:
     """
     ## Make a `nn.ModuleList` with clones of a given layer
     """
-    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
+    return TypedModuleList([copy.deepcopy(module) for _ in range(n)])
Authored by Varuna Jayasiri