mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
🚧 compressive transformer
Author: Varuna Jayasiri
labml_nn/transformers/compressive/__init__.py (new file, 183 lines)
@@ -0,0 +1,183 @@
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn

from labml_helpers.module import Module, TypedModuleList
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
from labml_nn.utils import clone_module_list


class Conv1dCompression(Module):
    """
    ## Memory compression

    Compresses memories along the sequence dimension with a 1D convolution
    whose kernel size and stride are equal to the compression ratio.
    """

    def __init__(self, compression_ratio: int, d_model: int):
        """
        * `compression_ratio` is the number of memories compressed into a single compressed memory
        * `d_model` is the embedding size
        """
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_ratio, stride=compression_ratio)

    def forward(self, mem: torch.Tensor):
        """
        * `mem` has shape `[seq_len, batch, d_model]`
        """

        # Change the dimensions of `mem` so that we can run it through the convolution layer.
        # The convolution layer accepts in the form `[batch, features, sequence]`
        mem = mem.permute(1, 2, 0)
        # Get compressed memory
        c_mem = self.conv(mem)
        # Permute back to form `[seq_len, batch, d_model]`
        return c_mem.permute(2, 0, 1)


class CompressiveTransformerLayer(Module):
    """
    ## Compressive Transformer Layer

    A single compressive transformer layer: self attention over the memories,
    compressed memories and the current tokens, followed by a position-wise
    feed-forward network.
    """

    def __init__(self, *,
                 d_model: int,
                 self_attn: RelativeMultiHeadAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float,
                 compress: Conv1dCompression):
        """
        * `d_model` is the token embedding size
        * `self_attn` is the [self attention module](relative_mha.html)
        * `feed_forward` is the feed forward module
        * `dropout_prob` is the probability of dropping out after self attention and FFN
        * `compress` is the memory compression module
        """
        super().__init__()
        self.compress = compress
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def with_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):
        """
        Concatenate the normalized memories and compressed memories with the
        normalized token embeddings `z` to form the keys and values for attention.
        """
        # If there is no memory, just return the token embeddings
        if mem is None:
            return z

        # If there are compressed memories, prepend them to the memories
        if c_mem is not None:
            mem = torch.cat((c_mem, mem), dim=0)

        # Normalize the memories and concatenate with the token embeddings
        mem = self.norm_self_attn(mem)
        return torch.cat((mem, z), dim=0)

    def forward(self, *,
                x: torch.Tensor,
                mem: Optional[torch.Tensor],
                c_mem: Optional[torch.Tensor],
                mask: torch.Tensor):
        """
        * `x` are the token level feature vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` are the past token level feature vectors of shape `[mem_len, batch_size, d_model]`
        * `c_mem` are the compressed memories of shape `[c_mem_len, batch_size, d_model]`
        * `mask` is a matrix of shape `[seq_len, c_mem_len + mem_len + seq_len, batch_size]` or `[seq_len, c_mem_len + mem_len + seq_len, 1]`.
        `mask[i, j]` is true if token at `i` can see token at `j`.
        """

        # Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
        # Concatenate the memories and compressed memories with the normalized tokens
        m_z = self.with_memory(z, mem, c_mem)
        # Attention
        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
        # Add the attention results
        x = x + self.dropout(self_attn)

        # Normalize for feed-forward
        z = self.norm_ff(x)
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        #
        return x


class CompressiveTransformer(Module):
    """
    ## Compressive Transformer Model

    This consists of multiple compressive transformer layers
    """

    def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):
        """
        * `x` are the token embedding vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` are the past token level feature vectors of shape `[mem_len, batch_size, d_model]` for each layer
        * `c_mem` are the compressed memories of shape `[c_mem_len, batch_size, d_model]` for each layer
        * `mask` is the masking matrix
        """
        # List to store token level feature vectors,
        # which will be the memories for the next sequential batch.
        new_mem = []
        # Run through each transformer layer
        for i, layer in enumerate(self.layers):
            # Add to the list of feature vectors
            new_mem.append(x.detach())
            # Memory
            m = mem[i] if mem else None
            # Compressed memory
            cm = c_mem[i] if c_mem else None
            # Run through the compressive transformer layer
            x = layer(x=x, mem=m, c_mem=cm, mask=mask)
        # Finally, normalize the vectors
        return self.norm(x), new_mem


class AttentionReconstructionLoss:
    """
    ## Attention Reconstruction Loss

    The MSE between the attention output computed over the original memories and
    the attention output computed over their compressed counterparts. The attention
    parameters are detached, so this loss only trains the compression functions.
    """

    def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
        """
        * `layers` is the list of compressive transformer layers
        """
        self.layers = layers
        self.loss_func = nn.MSELoss()

    def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):
        """
        Same transformation as `PrepareForMultiHeadAttention`, but with detached parameters.
        """
        head_shape = x.shape[:-1]

        # Linear transform
        weight = pmha.linear.weight.detach()
        bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None
        x = F.linear(x, weight, bias)

        # Split last dimension into heads
        x = x.view(*head_shape, pmha.heads, pmha.d_k)

        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`
        return x

    def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        query = self.prepare_for_attn(layer.query, query)
        key = self.prepare_for_attn(layer.key, key)
        value = self.prepare_for_attn(layer.value, value)

        # Compute attention scores $Q K^\top$.
        # This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
        scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= layer.scale

        # $softmax$ attention along the key sequence dimension
        # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = layer.softmax(scores)

        # Multiply by values
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
        return torch.einsum("ijbh,jbhd->ibhd", attn, value)

    def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
        # Detach the hidden states and memories; only the compression function receives gradients
        h = h.detach()
        mem = mem.detach()

        # Compress the memories
        c_mem = layer.compress(mem)

        # MSE between attention over the memories and attention over the compressed memories
        return self.loss_func(self.attn(layer.self_attn, h, mem, mem),
                              self.attn(layer.self_attn, h, c_mem, c_mem))

    def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):
        # Sum the reconstruction losses from all the layers
        losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]
        return sum(losses)
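Note (illustrative, not part of the commit): a minimal shape sketch of the convolutional compression above. It assumes the `labml_nn` package is installed so that the new module is importable at the path added by this commit.

import torch
from labml_nn.transformers.compressive import Conv1dCompression

# Compress 12 memories with a compression ratio of 4
compress = Conv1dCompression(compression_ratio=4, d_model=16)
mem = torch.randn(12, 2, 16)   # [mem_len, batch, d_model]
c_mem = compress(mem)
print(c_mem.shape)             # torch.Size([3, 2, 16]): 12 / 4 = 3 compressed memories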
labml_nn/transformers/compressive/experiment.py (new file, 327 lines)
@@ -0,0 +1,327 @@
"""
---
title: Compressive Transformer Experiment
summary: This experiment trains a compressive transformer model on the Tiny Shakespeare dataset.
---

# Compressive Transformer Experiment

This is an annotated PyTorch experiment to train a compressive transformer model.
"""
from typing import List, Tuple, NamedTuple

import torch
import torch.nn as nn

from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_helpers.metrics.simple_state import SimpleStateModule
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex, hook_model_outputs
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
    CompressiveTransformerLayer, Conv1dCompression


class CompressedMemory(NamedTuple):
    """
    Memories and compressed memories for each layer
    """
    mem: List[torch.Tensor]
    c_mem: List[torch.Tensor]


class AutoregressiveModel(Module):
    """
    ## Auto regressive model
    """

    def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Transformer
        self.transformer = transformer
        # Final layer
        self.generator = nn.Linear(d_model, n_vocab)
        # Masks
        self.mask_x = None
        self.mask_mem = None

    def forward(self, x: torch.Tensor, mem: CompressedMemory):
        # Unpack the memories and compressed memories
        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem = []
            c_mem = []

        # Total length of the memories and compressed memories (for the memory mask)
        m_len = len(mem[0]) if mem else 0
        if c_mem:
            m_len += len(c_mem[0])

        # Create a subsequent mask for tokens
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
        # Create an all ones (full visibility) mask for memory
        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

        # Concatenate the masks if there is memory
        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
        # Use the subsequent mask otherwise
        else:
            mask = self.mask_x[:len(x), :len(x)]

        # Token embeddings
        x = self.src_embed(x)
        # Run it through the transformer
        res, mem = self.transformer(x, mem, c_mem, mask)
        # Generate logits of the next token
        res = self.generator(res)
        #
        return res, mem


class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    The default configs can and will be overridden when we start the experiment
    """

    model: AutoregressiveModel

    # Token embedding size
    d_model: int = 128
    # Number of attention heads
    heads: int = 4
    # Dropout probability
    dropout: float = 0.0
    # Number of features in FFN hidden layer
    d_ff: int = 256
    # Number of transformer layers
    n_layers: int = 6
    # Number of memories to keep
    mem_len: int = 8
    # State module to maintain memories when switching between training and validation
    memory = SimpleStateModule()
    # Attention Reconstruction Loss
    attention_reconstruction_loss: AttentionReconstructionLoss
    # Compression ratio; number of memories compressed into one compressed memory
    compression_ratio: int = 4
    # Maximum number of compressed memories to keep
    c_mem_len: int = 128

    def init(self):
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("ar_loss.*", False)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # This will keep the accuracy metric stats and memories separate for training and validation.
        self.state_modules = [self.accuracy, self.memory]

    @torch.no_grad()
    def merge_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
            -> Tuple[CompressedMemory, List[torch.Tensor]]:
        """
        Concatenate the new memories with the old, keep at most `mem_len` uncompressed
        memories, and compress the overflowing oldest memories. Returns the updated
        `CompressedMemory` and the memories that were compressed (used for the
        attention reconstruction loss).
        """

        # If it's configured not to use memory
        if self.mem_len == 0:
            return CompressedMemory([], []), []

        # Unpack memories and compressed memories
        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem, c_mem = [], []
        # Concatenate with old memory
        if mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
        else:
            mem = new_mem

        # If there are more memories than `mem_len`, compress the oldest ones
        if len(mem[0]) > self.mem_len:
            # Number of compressed memories to create: ceil((len(mem) - mem_len) / compression_ratio)
            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_ratio - 1) // self.compression_ratio
            # Memories that get compressed and memories that remain uncompressed
            old_mem = []
            trunc_mem = []
            for m in mem:
                # Split off the oldest memories in whole multiples of the compression ratio
                n_old = n_c_mem * self.compression_ratio
                cm, m = torch.split(m, [n_old, len(m) - n_old])
                old_mem.append(cm)
                trunc_mem.append(m)
            mem = trunc_mem

            # Compress the oldest memories with each layer's compression module
            new_c_mem = []
            for i, layer in enumerate(self.model.transformer.layers):
                new_c_mem.append(layer.compress(old_mem[i]))

            # Concatenate with the existing compressed memories
            if c_mem:
                c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]
            else:
                c_mem = new_c_mem

            # Truncate the oldest compressed memories to keep at most `c_mem_len`
            if len(c_mem[0]) > self.c_mem_len:
                c_mem = [m[-self.c_mem_len:] for m in c_mem]
        else:
            # Nothing was compressed
            old_mem = []

        #
        return CompressedMemory(mem, c_mem), old_mem

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training/validation step
        """

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get memories
            mem = self.memory.get()
            # Run the model
            output, new_mem = self.model(data, mem)
            # Merge memory
            mem, old_mem = self.merge_memory(mem, new_mem)
            # Update memories
            self.memory.set(mem)

        # Calculate and log cross entropy loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log the attention reconstruction loss if any memories were compressed
        if old_mem:
            ar_loss = self.attention_reconstruction_loss(new_mem, old_mem)
            tracker.add("ar_loss.", ar_loss)
            # Adding this to the training loss is still disabled (work in progress)
            # loss = loss + ar_loss

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    def sample(self):
        """
        ### Sampling function to generate samples periodically while training
        """

        # Starting prompt
        prompt = self.prompt
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Start with an empty memory
        mem = CompressedMemory([], [])
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            # Move to device
            data = data.to(self.device)
            # Get the model output
            output, new_mem = self.model(data, mem)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze(1)
            # Add the prediction to prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Only feed the last character to the model in the next iteration; the rest go in as memories
            prompt = prompt[-1:]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
            # Update memory
            mem, _ = self.merge_memory(mem, new_mem)

        # Print the sampled output
        logger.log(log)


@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    ### Initialize the auto-regressive model
    """
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer(
        CompressiveTransformerLayer(d_model=c.d_model,
                                    self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                                    feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                    dropout_prob=c.dropout,
                                    compress=Conv1dCompression(c.compression_ratio, c.d_model)), c.n_layers))
    return m.to(c.device)


@option(Configs.attention_reconstruction_loss)
def attention_reconstruction_loss(c: Configs):
    """
    ### Initialize the attention reconstruction loss
    """
    return AttentionReconstructionLoss(c.model.transformer.layers)


def main():
    """
    ### Run the experiment
    """
    # Create experiment
    experiment.create(name="compressive_transformer", comment='')
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 2.5e-4,
                        'optimizer.optimizer': 'AdamW',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 8,
                        'mem_len': 8,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # `TrainValidConfigs.run`
        conf.run()


#
if __name__ == '__main__':
    main()
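Note (illustrative, not part of the commit): a small worked example of the integer bookkeeping in `merge_memory`, with made-up sizes. When the concatenated memory grows beyond `mem_len`, the oldest memories are compressed in whole multiples of the compression ratio, so the remaining uncompressed memory can temporarily drop below `mem_len`.

# Hypothetical sizes, for illustration only
mem_len, compression_ratio = 8, 4
total = 13  # memories per layer after concatenating the new memories

# Number of compressed memories to create: ceil((total - mem_len) / compression_ratio)
n_c_mem = (total - mem_len + compression_ratio - 1) // compression_ratio
# Oldest memories that get compressed (whole multiples of the compression ratio)
n_old = n_c_mem * compression_ratio
# Uncompressed memories that remain
kept = total - n_old

print(n_c_mem, n_old, kept)  # 2 8 5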
labml_nn/utils.py
@@ -9,13 +9,11 @@ summary: A bunch of utility functions and classes
 
 import copy
 
-from torch import nn
-
-from labml_helpers.module import Module
+from labml_helpers.module import M, TypedModuleList
 
 
-def clone_module_list(module: Module, n: int):
+def clone_module_list(module: M, n: int) -> TypedModuleList[M]:
     """
     ## Make a `nn.ModuleList` with clones of a given layer
     """
-    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
+    return TypedModuleList([copy.deepcopy(module) for _ in range(n)])
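Note (illustrative, not part of the commit): `clone_module_list` deep-copies the layer, so each clone gets independent parameters. This sketch assumes `labml_nn` and `labml_helpers` are installed and that `TypedModuleList` indexes like `nn.ModuleList`.

import torch
from torch import nn
from labml_nn.utils import clone_module_list

layer = nn.Linear(4, 4)
layers = clone_module_list(layer, 3)

# Zero one clone's weights; the other clones are unaffected because they are deep copies
with torch.no_grad():
    layers[0].weight.zero_()
print(torch.equal(layers[0].weight, layers[1].weight))  # False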