Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-11-04 14:29:43 +08:00)

Commit: ✨ glu variants simple experiment
@@ -16,7 +16,6 @@ from .feed_forward import FeedForward
 from .mha import MultiHeadAttention
 from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, TransformerLayer, \
     Encoder, Decoder, Generator, EncoderDecoder
-from .. import activations


 class FeedForwardConfigs(BaseConfigs):
@@ -102,7 +101,7 @@ aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
           (FeedForwardConfigs.bias1, False),
           (FeedForwardConfigs.bias2, False),
           (FeedForwardConfigs.bias_gate, False),
-          (FeedForwardConfigs.activation, activations.Swish()))
+          (FeedForwardConfigs.activation, nn.SiLU()))


 class TransformerConfigs(BaseConfigs):
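Note on the hunks above: the custom `activations.Swish` import is removed, and the SwiGLU configuration now uses PyTorch's built-in `nn.SiLU` for the gate activation. Swish with β = 1 is x · σ(x), which is exactly what `nn.SiLU` computes, so the configuration is functionally unchanged. A minimal sanity check of that equivalence (the `Swish` class here is a stand-in written for illustration, not the removed module):

import torch
from torch import nn


class Swish(nn.Module):
    # Stand-in for the removed activation, assumed to compute x * sigmoid(x)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)


x = torch.randn(4, 8)
# nn.SiLU computes the same function, so the two outputs should match
assert torch.allclose(Swish()(x), nn.SiLU()(x))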
@@ -104,7 +104,7 @@ def main():
                         'inner_iterations': 10,

                         # GLU Variant, one of GLU, Bilinear, ReGLU, GEGLU, SwiGLU
-                        'transformer.ffn.glu_variant': 'GLU',
+                        'transformer.ffn.glu_variant': 'Bilinear',

                         # Transformer configurations
                         'transformer.d_model': 256,
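The variants named in that comment differ only in the activation applied to the gate: GLU uses a sigmoid, Bilinear uses no activation, ReGLU uses ReLU, GEGLU uses GELU, and SwiGLU uses Swish/SiLU. A minimal sketch of the gated position-wise FFN these names describe, for illustration only (not the repository's `FeedForward` class):

import torch
from torch import nn


class GatedFFN(nn.Module):
    """Gated position-wise FFN: FFN(x) = (activation(x W1) * (x V)) W2."""

    def __init__(self, d_model: int, d_ff: int, activation: nn.Module):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # projection fed through the activation
        self.v = nn.Linear(d_model, d_ff, bias=False)   # linear gate
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # output projection
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.activation(self.w1(x)) * self.v(x))


# 'Bilinear', the variant selected above, applies no activation to the gate path
bilinear_ffn = GatedFFN(d_model=256, d_ff=1024, activation=nn.Identity())
swiglu_ffn = GatedFFN(d_model=256, d_ff=1024, activation=nn.SiLU())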
							
								
								
									
labml_nn/transformers/glu_variants/simple.py (new file, 225 lines)

@@ -0,0 +1,225 @@
"""
---
title: Gated Linear Units and Variants
summary: >
  Train an auto-regressive transformer with Gated Linear Units and variants
  for the position-wise feedforward network (FFN).
---

# Train Autoregressive Transformer

This trains a simple [transformer](../../) model for auto-regression.
"""
import dataclasses

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask


class AutoregressiveModel(nn.Module):
    """
    ## Auto-regressive model
    """

    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
        super().__init__()
        # Token embedding module
        self.src_embed = src_embed
        # Transformer based encoder
        self.encoder = encoder
        # Next token generation layer;
        # this gives the logits of the next token
        self.generator = generator
        # This will be initialized on the first call
        self.src_mask = None

    def __call__(self, src: torch.Tensor):
        # Create a subsequent mask, so that the transformer can only pay attention to past tokens.
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)
        # Embed the tokens (`src`) and run them through the transformer
        res = self.encoder(self.src_embed(src), self.src_mask)
        # Generate logits of the next token
        return self.generator(res)


@dataclasses.dataclass
class Configs:
    d_model: int = 512
    seq_len: int = 128
    batch_size: int = 32
    n_layers: int = 6
    n_heads: int = 8
    dropout: float = 0.1
    d_ff: int = 2048
    glu_variant: str = 'GLU'
    epochs: int = 5
    grad_norm_clip: float = 0.5


class TinyShakespeareDataset(Dataset):
    def __init__(self, seq_len: int):
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
        with open(str(path), 'r') as f:
            text = f.read()

        chars = list(set(text))
        self.stoi = {c: i for i, c in enumerate(chars)}
        self.itos = {i: c for i, c in enumerate(chars)}
        self.seq_len = seq_len
        self.data = self.text_to_i(text)

    def text_to_i(self, text: str):
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

    def __len__(self):
        return len(self.data) - self.seq_len - 1

    def __getitem__(self, idx):
        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]


class Trainer:
    def __init__(self, configs: Configs):
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        self.dataset = TinyShakespeareDataset(configs.seq_len)
        self.dataloader = DataLoader(self.dataset, batch_size=configs.batch_size, collate_fn=transpose_batch,
                                     shuffle=True)

        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')

        n_chars = len(self.dataset.stoi)
        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
                                         Encoder(TransformerLayer(
                                             d_model=configs.d_model,
                                             self_attn=MultiHeadAttention(configs.n_heads, configs.d_model,
                                                                          configs.dropout),
                                             src_attn=None,
                                             feed_forward=ffn,
                                             dropout_prob=configs.dropout
                                         ), configs.n_layers),
                                         nn.Linear(configs.d_model, n_chars))
        self.model.to(self.device)

        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)

        self.loss_func = nn.CrossEntropyLoss()
        self.epochs = configs.epochs
        self.grad_norm_clip = configs.grad_norm_clip

        # Set tracker configurations
        tracker.set_scalar("loss.*", True)

    def sample(self):
        """
        ### Sampling function to generate samples periodically while training
        """

        # Starting prompt
        prompt = 'It is'
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to the prompt
            prompt += self.dataset.itos[output[-1].item()]
            # Add the prediction for logging
            log += [(self.dataset.itos[output[-1].item()], Text.value)]

        # Print the sampled output
        logger.log(log)

    def train(self):
        for _ in monit.loop(self.epochs):
            for i, batch in monit.enum('Train', self.dataloader):
                # Move data to the device
                data, target = batch[0].to(self.device), batch[1].to(self.device)

                tracker.add_global_step(data.shape[0] * data.shape[1])

                self.model.train()
                output = self.model(data)

                # Calculate and log loss
                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
                tracker.add("loss.train", loss)

                # Calculate gradients
                loss.backward()
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
                # Take optimizer step
                self.optimizer.step()
                # Log the model parameters and gradients every 100 batches
                if (i + 1) % 100 == 0:
                    tracker.add('model', self.model)
                # Clear the gradients
                self.optimizer.zero_grad()

                if (i + 1) % 100 == 0:
                    self.model.eval()
                    with torch.no_grad():
                        self.sample()

                # Save the tracked metrics
                if (i + 1) % 10 == 0:
                    tracker.save()

            experiment.save_checkpoint()


def main():
    # Create experiment
    experiment.create(name="glu_variants")
    # Create configs
    configs = Configs()
    # Load configurations
    experiment.configs(dataclasses.asdict(configs))

    trainer = Trainer(configs)
    experiment.add_pytorch_models({'model': trainer.model})

    # Start the experiment
    with experiment.start():
        # `TrainValidConfigs.run`
        trainer.train()


if __name__ == '__main__':
    main()
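In `Trainer.__init__` above, the `FeedForward` arguments after the dropout are passed positionally. They are easier to read with keywords; the keyword names below (`is_gated`, `bias1`, `bias2`, `bias_gate`) are an assumption inferred from the `FeedForwardConfigs` options in the first hunk of this commit, so treat this as a sketch rather than the class's documented signature:

from torch import nn
from labml_nn.transformers.feed_forward import FeedForward

# The SwiGLU branch of Trainer.__init__, with the positional flags spelled out
# (keyword names assumed from the FeedForwardConfigs options: bias1, bias2, bias_gate)
ffn = FeedForward(512, 2048, 0.1,
                  activation=nn.SiLU(),  # gate activation (Swish/SiLU gives SwiGLU)
                  is_gated=True,         # gated, GLU-style FFN instead of a plain two-layer FFN
                  bias1=False,           # no bias on the first projection
                  bias2=False,           # no bias on the output projection
                  bias_gate=False)       # no bias on the gate projection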
							
								
								
									
setup.py (4 lines changed)

@@ -5,7 +5,7 @@ with open("readme.md", "r") as f:

 setuptools.setup(
     name='labml-nn',
-    version='0.4.81',
+    version='0.4.82',
     author="Varuna Jayasiri, Nipun Wijerathne",
     author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
     description="A collection of PyTorch implementations of neural network architectures and layers.",
@@ -20,7 +20,7 @@ setuptools.setup(
                                                'labml_helpers', 'labml_helpers.*',
                                                'test',
                                                'test.*')),
-    install_requires=['labml>=0.4.94',
+    install_requires=['labml>=0.4.97',
                       'labml-helpers>=0.4.72',
                       'torch',
                       'einops',