diff --git a/docs/transformers/feed_forward.html b/docs/transformers/feed_forward.html
index 145ec521..9420233f 100644
--- a/docs/transformers/feed_forward.html
+++ b/docs/transformers/feed_forward.html
@@ -226,7 +226,7 @@ be multiplied by the gate, parameterized by weight $V$ and bias $c$

-81    def __call__(self, x: torch.Tensor):
+81    def forward(self, x: torch.Tensor):
diff --git a/docs/transformers/feedback/experiment.html b/docs/transformers/feedback/experiment.html
index bfdefdeb..b44c76e7 100644
--- a/docs/transformers/feedback/experiment.html
+++ b/docs/transformers/feedback/experiment.html
@@ -138,7 +138,7 @@ where the keys and values are precalculated.

-44    def __call__(self, x: torch.Tensor):
+44    def forward(self, x: torch.Tensor):
diff --git a/docs/transformers/feedback/index.html b/docs/transformers/feedback/index.html
index 0cb8209d..69f8e480 100644
--- a/docs/transformers/feedback/index.html
+++ b/docs/transformers/feedback/index.html
@@ -421,7 +421,7 @@ A_{j} &= Q^\top K_j \\
-158    def __call__(self, *,
+158    def forward(self, *,
 159                 query: torch.Tensor,
 160                 key: torch.Tensor,
 161                 value: torch.Tensor):
@@ -609,7 +609,7 @@ Results in a tensor of shape [seq_len, batch_size, heads]

-229    def __call__(self, *,
+229    def forward(self, *,
 230                 x: torch.Tensor,
 231                 key: Optional[torch.Tensor],
 232                 value: Optional[torch.Tensor]):
@@ -794,7 +794,7 @@ This is the weights parameter for that.

-275    def __call__(self, x_seq: torch.Tensor):
+275    def forward(self, x_seq: torch.Tensor):
@@ -1458,7 +1458,7 @@ This is the weights parameter for that.

-473    def __call__(self, x_seq: torch.Tensor):
+473    def forward(self, x_seq: torch.Tensor):
diff --git a/docs/transformers/glu_variants/experiment.html b/docs/transformers/glu_variants/experiment.html
index c4edd1c6..57e7b31e 100644
--- a/docs/transformers/glu_variants/experiment.html
+++ b/docs/transformers/glu_variants/experiment.html
@@ -165,7 +165,7 @@ this give logits of the the next token

-44    def __call__(self, src: torch.Tensor):
+44    def forward(self, src: torch.Tensor):
diff --git a/docs/transformers/glu_variants/simple.html b/docs/transformers/glu_variants/simple.html
index 3aefe0e7..5cddd021 100644
--- a/docs/transformers/glu_variants/simple.html
+++ b/docs/transformers/glu_variants/simple.html
@@ -84,18 +84,19 @@ We decided to write a simpler implementation to make it easier readers who are n
20import dataclasses
 21
 22import torch
-23from torch import nn
-24from torch.utils.data import Dataset, DataLoader
-25
-26from labml import experiment, lab, tracker, monit, logger
-27from labml.logger import Text
-28from labml.utils.download import download_file
-29from labml_nn.experiments.nlp_autoregression import transpose_batch
-30from labml_nn.optimizers.noam import Noam
-31from labml_nn.transformers import Encoder, MultiHeadAttention
-32from labml_nn.transformers.feed_forward import FeedForward
-33from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
-34from labml_nn.transformers.utils import subsequent_mask
+23from labml_helpers.module import Module
+24from torch import nn
+25from torch.utils.data import Dataset, DataLoader
+26
+27from labml import experiment, lab, tracker, monit, logger
+28from labml.logger import Text
+29from labml.utils.download import download_file
+30from labml_nn.experiments.nlp_autoregression import transpose_batch
+31from labml_nn.optimizers.noam import Noam
+32from labml_nn.transformers import Encoder, MultiHeadAttention
+33from labml_nn.transformers.feed_forward import FeedForward
+34from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
+35from labml_nn.transformers.utils import subsequent_mask
@@ -106,7 +107,7 @@ We decided to write a simpler implementation to make it easier readers who are n

Auto regressive model

-37class AutoregressiveModel(nn.Module):
+38class AutoregressiveModel(Module):
@@ -117,8 +118,8 @@ We decided to write a simpler implementation to make it easier readers who are n
-42    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
-43        super().__init__()
+43    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
+44        super().__init__()
@@ -129,7 +130,7 @@ We decided to write a simpler implementation to make it easier readers who are n

Token embedding module

-45        self.src_embed = src_embed
+46        self.src_embed = src_embed
@@ -140,7 +141,7 @@ We decided to write a simpler implementation to make it easier readers who are n

Transformer based encoder

-47        self.encoder = encoder
+48        self.encoder = encoder
@@ -152,7 +153,7 @@ We decided to write a simpler implementation to make it easier readers who are n

this give logits of the the next token

-50        self.generator = generator
+51        self.generator = generator
@@ -163,7 +164,7 @@ this give logits of the the next token

This will be initialized on the first call

-52        self.src_mask = None
+53        self.src_mask = None
@@ -174,7 +175,7 @@ this give logits of the the next token

-54    def __call__(self, src: torch.Tensor):
+55    def forward(self, src: torch.Tensor):
@@ -185,8 +186,8 @@ this give logits of the the next token

Create subsequent mask, so that the transformer can only pay attention to past tokens.

-56        if self.src_mask is None or self.src_mask.size(0) != len(src):
-57            self.src_mask = subsequent_mask(len(src)).to(src.device)
+57        if self.src_mask is None or self.src_mask.size(0) != len(src):
+58            self.src_mask = subsequent_mask(len(src)).to(src.device)
@@ -197,7 +198,7 @@ this give logits of the the next token

Embed the tokens (src) and run it through the the transformer

-59        res = self.encoder(self.src_embed(src), self.src_mask)
+60        res = self.encoder(self.src_embed(src), self.src_mask)
@@ -208,7 +209,7 @@ this give logits of the the next token

Generate logits of the next token

-61        return self.generator(res)
+62        return self.generator(res)
@@ -219,8 +220,8 @@ this give logits of the the next token

Configurations

-64@dataclasses.dataclass
-65class Configs:
+65@dataclasses.dataclass
+66class Configs:
@@ -231,16 +232,16 @@ this give logits of the the next token

-69    d_model: int = 512
-70    seq_len: int = 128
-71    batch_size: int = 32
-72    n_layers: int = 6
-73    n_heads: int = 8
-74    dropout: float = 0.1
-75    d_ff: int = 2048
-76    glu_variant: str = 'GLU'
-77    epochs: int = 5
-78    grad_norm_clip: float = 0.5
+70    d_model: int = 512
+71    seq_len: int = 128
+72    batch_size: int = 32
+73    n_layers: int = 6
+74    n_heads: int = 8
+75    dropout: float = 0.1
+76    d_ff: int = 2048
+77    glu_variant: str = 'GLU'
+78    epochs: int = 5
+79    grad_norm_clip: float = 0.5
@@ -251,7 +252,7 @@ this give logits of the the next token

Tiny Shakespeare Dataset

-81class TinyShakespeareDataset(Dataset):
+82class TinyShakespeareDataset(Dataset):
@@ -262,7 +263,7 @@ this give logits of the the next token

-86    def __init__(self, seq_len: int):
+87    def __init__(self, seq_len: int):
@@ -273,7 +274,7 @@ this give logits of the the next token

Location of the text file

-88        path = lab.get_data_path() / 'tiny_shakespeare.txt'
+89        path = lab.get_data_path() / 'tiny_shakespeare.txt'
@@ -284,7 +285,7 @@ this give logits of the the next token

Download the file

-90        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
+91        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
@@ -295,8 +296,8 @@ this give logits of the the next token

Read the downloaded file

-92        with open(str(path), 'r') as f:
-93            text = f.read()
+93        with open(str(path), 'r') as f:
+94            text = f.read()
@@ -307,7 +308,7 @@ this give logits of the the next token

Extract the characters

-96        chars = list(set(text))
+97        chars = list(set(text))
@@ -318,7 +319,7 @@ this give logits of the the next token

Character to id (integer) map

-98        self.stoi = {c: i for i, c in enumerate(chars)}
+99        self.stoi = {c: i for i, c in enumerate(chars)}
@@ -329,7 +330,7 @@ this give logits of the the next token

Id to character map

-100        self.itos = {i: c for i, c in enumerate(chars)}
+101        self.itos = {i: c for i, c in enumerate(chars)}
@@ -340,7 +341,7 @@ this give logits of the the next token

Length of a training sample

-102        self.seq_len = seq_len
+103        self.seq_len = seq_len
@@ -351,7 +352,7 @@ this give logits of the the next token

Data in the form of a tensor of ids

-104        self.data = self.text_to_i(text)
+105        self.data = self.text_to_i(text)
@@ -362,7 +363,7 @@ this give logits of the the next token

Transform the text into a tensor of ids

-106    def text_to_i(self, text: str):
+107    def text_to_i(self, text: str):
@@ -373,7 +374,7 @@ this give logits of the the next token

-110        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
+111        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
@@ -385,7 +386,7 @@ this give logits of the the next token

This will read the dataset seq_len times in a single epoch.

-112    def __len__(self):
+113    def __len__(self):
@@ -396,7 +397,7 @@ this give logits of the the next token

-118        return len(self.data) - self.seq_len - 1
+119        return len(self.data) - self.seq_len - 1
@@ -407,7 +408,7 @@ this give logits of the the next token

Return a sample

-120    def __getitem__(self, idx):
+121    def __getitem__(self, idx):
@@ -418,7 +419,7 @@ this give logits of the the next token

-124        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
+125        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
@@ -429,7 +430,7 @@ this give logits of the the next token

Trainer

-127class Trainer:
+128class Trainer:
@@ -440,7 +441,7 @@ this give logits of the the next token

-132    def __init__(self, configs: Configs):
+133    def __init__(self, configs: Configs):
@@ -451,9 +452,9 @@ this give logits of the the next token

Get the device

-134        self.device = torch.device('cpu')
-135        if torch.cuda.is_available():
-136            self.device = torch.device('cuda:0')
+135        self.device = torch.device('cpu')
+136        if torch.cuda.is_available():
+137            self.device = torch.device('cuda:0')
@@ -464,7 +465,7 @@ this give logits of the the next token

Initialize the dataset

-138        self.dataset = TinyShakespeareDataset(configs.seq_len)
+139        self.dataset = TinyShakespeareDataset(configs.seq_len)
@@ -475,10 +476,10 @@ this give logits of the the next token

Initialize the dataloader

-140        self.dataloader = DataLoader(self.dataset,
-141                                     batch_size=configs.batch_size,
-142                                     collate_fn=transpose_batch,
-143                                     shuffle=True)
+141        self.dataloader = DataLoader(self.dataset,
+142                                     batch_size=configs.batch_size,
+143                                     collate_fn=transpose_batch,
+144                                     shuffle=True)
@@ -491,8 +492,8 @@ this give logits of the the next token

-147        if configs.glu_variant == 'GLU':
-148            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
+148        if configs.glu_variant == 'GLU':
+149            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
@@ -505,8 +506,8 @@ this give logits of the the next token

-151        elif configs.glu_variant == 'Bilinear':
-152            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
+152        elif configs.glu_variant == 'Bilinear':
+153            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
@@ -519,8 +520,8 @@ this give logits of the the next token

-155        elif configs.glu_variant == 'ReGLU':
-156            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
+156        elif configs.glu_variant == 'ReGLU':
+157            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
@@ -533,8 +534,8 @@ this give logits of the the next token

-159        elif configs.glu_variant == 'GEGLU':
-160            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
+160        elif configs.glu_variant == 'GEGLU':
+161            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
@@ -547,8 +548,8 @@ this give logits of the the next token

where $\text{Swish}_\beta(x) = x \sigma(\beta x)$

-164        elif configs.glu_variant == 'SwiGLU':
-165            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
+165        elif configs.glu_variant == 'SwiGLU':
+166            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
@@ -561,8 +562,8 @@ where $\text{Swish}_\beta(x) = x \sigma(\beta x)$

-168        elif configs.glu_variant == 'ReLU':
-169            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
+169        elif configs.glu_variant == 'ReLU':
+170            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
@@ -575,10 +576,10 @@ where $\text{Swish}_\beta(x) = x \sigma(\beta x)$

-172        elif configs.glu_variant == 'GELU':
-173            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
-174        else:
-175            raise ValueError(f'Unknown variant {configs.glu_variant}')
+173        elif configs.glu_variant == 'GELU':
+174            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
+175        else:
+176            raise ValueError(f'Unknown variant {configs.glu_variant}')
@@ -589,7 +590,7 @@ where $\text{Swish}_\beta(x) = x \sigma(\beta x)$

Number of different characters

-178        n_chars = len(self.dataset.stoi)
+179        n_chars = len(self.dataset.stoi)
@@ -600,7 +601,7 @@ where $\text{Swish}_\beta(x) = x \sigma(\beta x)$

Initialize Multi-Head Attention module

-181        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
+182        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
@@ -611,8 +612,8 @@ where $\text{Swish}_\beta(x) = x \sigma(\beta x)$

Initialize the Transformer Block

-183        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
-184                                             feed_forward=ffn, dropout_prob=configs.dropout)
+184        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
+185                                             feed_forward=ffn, dropout_prob=configs.dropout)
@@ -627,9 +628,9 @@ where $\text{Swish}_\beta(x) = x \sigma(\beta x)$

a linear layer to generate logits.

-190        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
-191                                         Encoder(transformer_layer, configs.n_layers),
-192                                         nn.Linear(configs.d_model, n_chars))
+191        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
+192                                         Encoder(transformer_layer, configs.n_layers),
+193                                         nn.Linear(configs.d_model, n_chars))
@@ -640,7 +641,7 @@ a linear layer to generate logits.

Move the model to the current device

-195        self.model.to(self.device)
+196        self.model.to(self.device)
@@ -651,7 +652,7 @@ a linear layer to generate logits.

Initialize Noam optimizer

-198        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
+199        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
@@ -662,7 +663,7 @@ a linear layer to generate logits.

Cross-entropy loss

-201        self.loss_func = nn.CrossEntropyLoss()
+202        self.loss_func = nn.CrossEntropyLoss()
@@ -674,7 +675,7 @@ a linear layer to generate logits.

*note that our dataset definition repeats the data seq_len times in a single epoch

-204        self.epochs = configs.epochs
+205        self.epochs = configs.epochs
@@ -685,7 +686,7 @@ a linear layer to generate logits.

Gradient clipping norm

-206        self.grad_norm_clip = configs.grad_norm_clip
+207        self.grad_norm_clip = configs.grad_norm_clip
@@ -696,7 +697,7 @@ a linear layer to generate logits.

Set tracker configurations

-209        tracker.set_scalar("loss.*", True)
+210        tracker.set_scalar("loss.*", True)
@@ -707,7 +708,7 @@ a linear layer to generate logits.

Sampling function to generate samples periodically while training

-211    def sample(self):
+212    def sample(self):
@@ -718,7 +719,7 @@ a linear layer to generate logits.

Starting prompt

-217        prompt = 'It is'
+218        prompt = 'It is'
@@ -729,7 +730,7 @@ a linear layer to generate logits.

Collect output for printing

-219        log = [(prompt, Text.subtle)]
+220        log = [(prompt, Text.subtle)]
@@ -740,7 +741,7 @@ a linear layer to generate logits.

Sample 25 tokens

-221        for i in monit.iterate('Sample', 25):
+222        for i in monit.iterate('Sample', 25):
@@ -751,8 +752,8 @@ a linear layer to generate logits.

Tokenize the prompt

-223            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
-224            data = data.to(self.device)
+224            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
+225            data = data.to(self.device)
@@ -763,7 +764,7 @@ a linear layer to generate logits.

Get the model output

-226            output = self.model(data)
+227            output = self.model(data)
@@ -774,7 +775,7 @@ a linear layer to generate logits.

Get the model prediction (greedy)

-228            output = output.argmax(dim=-1).squeeze()
+229            output = output.argmax(dim=-1).squeeze()
@@ -785,7 +786,7 @@ a linear layer to generate logits.

Add the prediction to prompt

-230            prompt += self.dataset.itos[output[-1].item()]
+231            prompt += self.dataset.itos[output[-1].item()]
@@ -796,7 +797,7 @@ a linear layer to generate logits.

Add the prediction for logging

-232            log += [(self.dataset.itos[output[-1].item()], Text.value)]
+233            log += [(self.dataset.itos[output[-1].item()], Text.value)]
@@ -807,7 +808,7 @@ a linear layer to generate logits.

Print the sampled output

-235        logger.log(log)
+236        logger.log(log)
@@ -818,7 +819,7 @@ a linear layer to generate logits.

Train the model

-237    def train(self):
+238    def train(self):
@@ -829,7 +830,7 @@ a linear layer to generate logits.

Loop for the given number of epochs

-243        for _ in monit.loop(self.epochs):
+244        for _ in monit.loop(self.epochs):
@@ -840,7 +841,7 @@ a linear layer to generate logits.

Iterate over the minibatches

-245            for i, batch in monit.enum('Train', self.dataloader):
+246            for i, batch in monit.enum('Train', self.dataloader):
@@ -851,7 +852,7 @@ a linear layer to generate logits.

Move data to the device

-247                data, target = batch[0].to(self.device), batch[1].to(self.device)
+248                data, target = batch[0].to(self.device), batch[1].to(self.device)
@@ -862,7 +863,7 @@ a linear layer to generate logits.

Set tracker step, as the number of characters trained on

-250                tracker.add_global_step(data.shape[0] * data.shape[1])
+251                tracker.add_global_step(data.shape[0] * data.shape[1])
@@ -873,7 +874,7 @@ a linear layer to generate logits.

Set model state to training

-253                self.model.train()
+254                self.model.train()
@@ -884,7 +885,7 @@ a linear layer to generate logits.

Evaluate the model

-255                output = self.model(data)
+256                output = self.model(data)
@@ -895,7 +896,7 @@ a linear layer to generate logits.

Calculate loss

-258                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
+259                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
@@ -906,7 +907,7 @@ a linear layer to generate logits.

Log the loss

-260                tracker.add("loss.train", loss)
+261                tracker.add("loss.train", loss)
@@ -917,7 +918,7 @@ a linear layer to generate logits.

Calculate gradients

-263                loss.backward()
+264                loss.backward()
@@ -928,7 +929,7 @@ a linear layer to generate logits.

Clip gradients

-265                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
+266                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
@@ -939,7 +940,7 @@ a linear layer to generate logits.

Take optimizer step

-267                self.optimizer.step()
+268                self.optimizer.step()
@@ -950,8 +951,8 @@ a linear layer to generate logits.

Log the model parameters and gradients

-269                if (i + 1) % 100 == 0:
-270                    tracker.add('model', self.model)
+270                if (i + 1) % 100 == 0:
+271                    tracker.add('model', self.model)
@@ -962,7 +963,7 @@ a linear layer to generate logits.

Clear the gradients

-272                self.optimizer.zero_grad()
+273                self.optimizer.zero_grad()
@@ -973,10 +974,10 @@ a linear layer to generate logits.

Generate a sample

-275                if (i + 1) % 100 == 0:
-276                    self.model.eval()
-277                    with torch.no_grad():
-278                        self.sample()
+276                if (i + 1) % 100 == 0:
+277                    self.model.eval()
+278                    with torch.no_grad():
+279                        self.sample()
@@ -987,8 +988,8 @@ a linear layer to generate logits.

Save the tracked metrics

-281                if (i + 1) % 10 == 0:
-282                    tracker.save()
+282                if (i + 1) % 10 == 0:
+283                    tracker.save()
@@ -999,7 +1000,7 @@ a linear layer to generate logits.

Save the model

-285            experiment.save_checkpoint()
+286            experiment.save_checkpoint()
@@ -1010,7 +1011,7 @@ a linear layer to generate logits.

-288def main():
+289def main():
@@ -1021,7 +1022,7 @@ a linear layer to generate logits.

Create experiment

-290    experiment.create(name="glu_variants")
+291    experiment.create(name="glu_variants")
@@ -1032,7 +1033,7 @@ a linear layer to generate logits.

Create configs

-292    configs = Configs()
+293    configs = Configs()
@@ -1043,7 +1044,7 @@ a linear layer to generate logits.

Load configurations

-294    experiment.configs(dataclasses.asdict(configs))
+295    experiment.configs(dataclasses.asdict(configs))
@@ -1054,7 +1055,7 @@ a linear layer to generate logits.

Create trainer

-297    trainer = Trainer(configs)
+298    trainer = Trainer(configs)
@@ -1065,7 +1066,7 @@ a linear layer to generate logits.

Set models for training and loading

-299    experiment.add_pytorch_models({'model': trainer.model})
+300    experiment.add_pytorch_models({'model': trainer.model})
@@ -1076,7 +1077,7 @@ a linear layer to generate logits.

Start the experiment

-302    with experiment.start():
+303    with experiment.start():
@@ -1087,11 +1088,11 @@ a linear layer to generate logits.

Train the model

-304        trainer.train()
-305
+305        trainer.train()
 306
-307if __name__ == '__main__':
-308    main()
+307
+308if __name__ == '__main__':
+309    main()
diff --git a/docs/transformers/gpt/index.html b/docs/transformers/gpt/index.html
index c9e004c7..e576b3e6 100644
--- a/docs/transformers/gpt/index.html
+++ b/docs/transformers/gpt/index.html
@@ -168,7 +168,7 @@ a final linear layer that gives token logits.

-70    def __call__(self, x: torch.Tensor):
+70    def forward(self, x: torch.Tensor):
diff --git a/docs/transformers/knn/train_model.html b/docs/transformers/knn/train_model.html
index 9e552bd5..14c0dddb 100644
--- a/docs/transformers/knn/train_model.html
+++ b/docs/transformers/knn/train_model.html
@@ -200,7 +200,7 @@ this give logits of the the next token

-52    def __call__(self, src: torch.Tensor):
+52    def forward(self, src: torch.Tensor):
diff --git a/docs/transformers/label_smoothing_loss.html b/docs/transformers/label_smoothing_loss.html
index fb2ac17b..0bd0d15e 100644
--- a/docs/transformers/label_smoothing_loss.html
+++ b/docs/transformers/label_smoothing_loss.html
@@ -119,7 +119,7 @@
-29    def __call__(self, x: torch.Tensor, target: torch.Tensor):
+29    def forward(self, x: torch.Tensor, target: torch.Tensor):
 30        assert x.shape[1] == self.size
 31        true_dist = x.clone()
 32        true_dist.fill_(self.smoothing / (self.size - 2))
diff --git a/docs/transformers/mha.html b/docs/transformers/mha.html
index 6c19e8f1..9f3f1478 100644
--- a/docs/transformers/mha.html
+++ b/docs/transformers/mha.html
@@ -155,7 +155,7 @@ This is used to transform key, query, and 
             
-45    def __call__(self, x: torch.Tensor):
+45    def forward(self, x: torch.Tensor):
@@ -377,7 +377,7 @@ They have shape [seq_len, batch_size, d_model].

query at position i has access to key-value at position j.

-121    def __call__(self, *,
+121    def forward(self, *,
 122                 query: torch.Tensor,
 123                 key: torch.Tensor,
 124                 value: torch.Tensor,
diff --git a/docs/transformers/models.html b/docs/transformers/models.html
index 5ea0ba11..15f793cb 100644
--- a/docs/transformers/models.html
+++ b/docs/transformers/models.html
@@ -122,7 +122,7 @@
                 
             
-36    def __call__(self, x: torch.Tensor):
+36    def forward(self, x: torch.Tensor):
 37        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
 38        return self.linear(x) * math.sqrt(self.d_model) + pe
@@ -163,7 +163,7 @@
-54    def __call__(self, x: torch.Tensor):
+54    def forward(self, x: torch.Tensor):
 55        pe = self.positional_encodings[:x.shape[0]]
 56        return self.linear(x) * math.sqrt(self.d_model) + pe
@@ -251,7 +251,7 @@ We found a detailed discussion about this in the paper
-103    def __call__(self, *,
+103    def forward(self, *,
 104                 x: torch.Tensor,
 105                 mask: torch.Tensor,
 106                 src: torch.Tensor = None,
@@ -439,7 +439,7 @@ encoder outputs

-153    def __call__(self, x: torch.Tensor, mask: torch.Tensor):
+153    def forward(self, x: torch.Tensor, mask: torch.Tensor):
@@ -520,7 +520,7 @@ encoder outputs

-175    def __call__(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+175    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
@@ -582,7 +582,7 @@ You don’t need this if you are using nn.CrossEntropyLoss.

-197    def __call__(self, x):
+197    def forward(self, x):
 198        return self.projection(x)
@@ -638,7 +638,7 @@ Initialize parameters with Glorot / fan_avg.

-222    def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+222    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
diff --git a/docs/transformers/positional_encoding.html b/docs/transformers/positional_encoding.html
index 9f1996dc..584ee620 100644
--- a/docs/transformers/positional_encoding.html
+++ b/docs/transformers/positional_encoding.html
@@ -127,7 +127,7 @@ PE_{p,2i + 1} &= cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)
-39    def __call__(self, x: torch.Tensor):
+39    def forward(self, x: torch.Tensor):
 40        pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
 41        x = x + pe
 42        x = self.dropout(x)
diff --git a/docs/transformers/switch/experiment.html b/docs/transformers/switch/experiment.html
index fa82fc68..ba225c83 100644
--- a/docs/transformers/switch/experiment.html
+++ b/docs/transformers/switch/experiment.html
@@ -151,7 +151,7 @@
                 
             
-37    def __call__(self, x: torch.Tensor):
+37    def forward(self, x: torch.Tensor):
diff --git a/docs/transformers/switch/index.html b/docs/transformers/switch/index.html
index bb4d2ab6..23e6481b 100644
--- a/docs/transformers/switch/index.html
+++ b/docs/transformers/switch/index.html
@@ -192,7 +192,7 @@ discusses dropping tokens when routing is not balanced.

-83    def __call__(self, x: torch.Tensor):
+83    def forward(self, x: torch.Tensor):
@@ -525,7 +525,7 @@ with handling extra outputs of switch feedforward module.

-192    def __call__(self, *,
+192    def forward(self, *,
 193                 x: torch.Tensor,
 194                 mask: torch.Tensor):
@@ -651,7 +651,7 @@ with handling extra outputs of switch feedforward module.

-224    def __call__(self, x: torch.Tensor, mask: torch.Tensor):
+224    def forward(self, x: torch.Tensor, mask: torch.Tensor):
diff --git a/labml_nn/transformers/feed_forward.py b/labml_nn/transformers/feed_forward.py
index d700d778..fce34ed5 100644
--- a/labml_nn/transformers/feed_forward.py
+++ b/labml_nn/transformers/feed_forward.py
@@ -78,7 +78,7 @@ class FeedForward(Module):
             # be multiplied by the gate, parameterized by weight $V$ and bias $c$
             self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # $f(x W_1 + b_1)$
         g = self.activation(self.layer1(x))
         # If gated, $f(x W_1 + b_1) \otimes (x V + b) $
diff --git a/labml_nn/transformers/feedback/__init__.py b/labml_nn/transformers/feedback/__init__.py
index 29bef0ba..740a4b62 100644
--- a/labml_nn/transformers/feedback/__init__.py
+++ b/labml_nn/transformers/feedback/__init__.py
@@ -155,7 +155,7 @@ class FeedbackAttention(Module):
         # $A_j$
         return ac + bd
 
-    def __call__(self, *,
+    def forward(self, *,
                  query: torch.Tensor,
                  key: torch.Tensor,
                  value: torch.Tensor):
@@ -226,7 +226,7 @@ class FeedbackTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])
 
-    def __call__(self, *,
+    def forward(self, *,
                  x: torch.Tensor,
                  key: Optional[torch.Tensor],
                  value: Optional[torch.Tensor]):
@@ -272,7 +272,7 @@ class FeedbackTransformer(Module):
         # Softmax for weights before taking the weighted sum
         self.softmax = nn.Softmax(0)
 
-    def __call__(self, x_seq: torch.Tensor):
+    def forward(self, x_seq: torch.Tensor):
         """
         * `x_seq` is the input with shape `[seq_len, batch_size, d_model]`
         """
@@ -470,7 +470,7 @@ class FeedbackTransformerKV(Module):
         # Memory for stacked values
         self.mem_value = Stack(512)
 
-    def __call__(self, x_seq: torch.Tensor):
+    def forward(self, x_seq: torch.Tensor):
         """
         * `x_seq` is the input with shape `[seq_len, batch_size, d_model]`
         """
diff --git a/labml_nn/transformers/feedback/experiment.py b/labml_nn/transformers/feedback/experiment.py
index 791869c2..fe068372 100644
--- a/labml_nn/transformers/feedback/experiment.py
+++ b/labml_nn/transformers/feedback/experiment.py
@@ -41,7 +41,7 @@ class AutoregressiveModel(Module):
         self.transformer = transformer
         self.generator = nn.Linear(d_model, n_vocab)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Embed the tokens
         x = self.src_embed(x)
         # Run it through the the transformer
diff --git a/labml_nn/transformers/glu_variants/experiment.py b/labml_nn/transformers/glu_variants/experiment.py
index 89c21b70..de2768bc 100644
--- a/labml_nn/transformers/glu_variants/experiment.py
+++ b/labml_nn/transformers/glu_variants/experiment.py
@@ -41,7 +41,7 @@ class AutoregressiveModel(Module):
         # This will be initialized on the first call
         self.src_mask = None
 
-    def __call__(self, src: torch.Tensor):
+    def forward(self, src: torch.Tensor):
         # Create subsequent mask, so that the transformer can only pay attention to past tokens.
         if self.src_mask is None or self.src_mask.size(0) != len(src):
             self.src_mask = subsequent_mask(len(src)).to(src.device)
diff --git a/labml_nn/transformers/glu_variants/simple.py b/labml_nn/transformers/glu_variants/simple.py
index eaf3fac2..7366b1d8 100644
--- a/labml_nn/transformers/glu_variants/simple.py
+++ b/labml_nn/transformers/glu_variants/simple.py
@@ -20,6 +20,7 @@ We decided to write a simpler implementation to make it easier readers who are n
 import dataclasses
 
 import torch
+from labml_helpers.module import Module
 from torch import nn
 from torch.utils.data import Dataset, DataLoader
 
@@ -34,7 +35,7 @@ from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, Trans
 from labml_nn.transformers.utils import subsequent_mask
 
 
-class AutoregressiveModel(nn.Module):
+class AutoregressiveModel(Module):
     """
     ## Auto regressive model
     """
@@ -51,7 +52,7 @@ class AutoregressiveModel(nn.Module):
         # This will be initialized on the first call
         self.src_mask = None
 
-    def __call__(self, src: torch.Tensor):
+    def forward(self, src: torch.Tensor):
         # Create subsequent mask, so that the transformer can only pay attention to past tokens.
         if self.src_mask is None or self.src_mask.size(0) != len(src):
             self.src_mask = subsequent_mask(len(src)).to(src.device)
diff --git a/labml_nn/transformers/gpt/__init__.py b/labml_nn/transformers/gpt/__init__.py
index fc715b81..cc083055 100644
--- a/labml_nn/transformers/gpt/__init__.py
+++ b/labml_nn/transformers/gpt/__init__.py
@@ -67,7 +67,7 @@ class GPT(Module):
         # The mask will be initialized on the first call
         self.mask = None
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Create subsequent mask if mask is not initialized
         # or if the size of the mask is different
         if self.mask is None or self.mask.size(0) != len(x):
diff --git a/labml_nn/transformers/knn/train_model.py b/labml_nn/transformers/knn/train_model.py
index 1cb1b693..83ac627d 100644
--- a/labml_nn/transformers/knn/train_model.py
+++ b/labml_nn/transformers/knn/train_model.py
@@ -49,7 +49,7 @@ class AutoregressiveModel(Module):
         """
         return self.encoder.layers[-1].ff_input
 
-    def __call__(self, src: torch.Tensor):
+    def forward(self, src: torch.Tensor):
         # Create subsequent mask, so that the transformer can only pay attention to past tokens.
         if self.src_mask is None or self.src_mask.size(0) != len(src):
             self.src_mask = subsequent_mask(len(src)).to(src.device)
diff --git a/labml_nn/transformers/label_smoothing_loss.py b/labml_nn/transformers/label_smoothing_loss.py
index 20aa109a..38296b4c 100644
--- a/labml_nn/transformers/label_smoothing_loss.py
+++ b/labml_nn/transformers/label_smoothing_loss.py
@@ -26,7 +26,7 @@ class LabelSmoothingLoss(Module):
         self.size = size
         self.true_dist = None
 
-    def __call__(self, x: torch.Tensor, target: torch.Tensor):
+    def forward(self, x: torch.Tensor, target: torch.Tensor):
         assert x.shape[1] == self.size
         true_dist = x.clone()
         true_dist.fill_(self.smoothing / (self.size - 2))
diff --git a/labml_nn/transformers/mha.py b/labml_nn/transformers/mha.py
index 6e897181..ba42debb 100644
--- a/labml_nn/transformers/mha.py
+++ b/labml_nn/transformers/mha.py
@@ -42,7 +42,7 @@ class PrepareForMultiHeadAttention(Module):
         # Number of dimensions in vectors in each head
         self.d_k = d_k
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
         # We apply the linear transformation to the last dimension and split that into
         # the heads.
@@ -118,7 +118,7 @@ class MultiHeadAttention(Module):
         # Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
         return torch.einsum('ibhd,jbhd->ijbh', query, key)
 
-    def __call__(self, *,
+    def forward(self, *,
                  query: torch.Tensor,
                  key: torch.Tensor,
                  value: torch.Tensor,
diff --git a/labml_nn/transformers/models.py b/labml_nn/transformers/models.py
index 3c89b099..a9848cff 100644
--- a/labml_nn/transformers/models.py
+++ b/labml_nn/transformers/models.py
@@ -33,7 +33,7 @@ class EmbeddingsWithPositionalEncoding(Module):
         self.d_model = d_model
         self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
         return self.linear(x) * math.sqrt(self.d_model) + pe
 
@@ -51,7 +51,7 @@ class EmbeddingsWithLearnedPositionalEncoding(Module):
         self.d_model = d_model
         self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         pe = self.positional_encodings[:x.shape[0]]
         return self.linear(x) * math.sqrt(self.d_model) + pe
 
@@ -100,7 +100,7 @@ class TransformerLayer(Module):
         # Whether to save input to the feed forward layer
         self.is_save_ff_input = False
 
-    def __call__(self, *,
+    def forward(self, *,
                  x: torch.Tensor,
                  mask: torch.Tensor,
                  src: torch.Tensor = None,
@@ -150,7 +150,7 @@ class Encoder(Module):
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])
 
-    def __call__(self, x: torch.Tensor, mask: torch.Tensor):
+    def forward(self, x: torch.Tensor, mask: torch.Tensor):
         # Run through each transformer layer
         for layer in self.layers:
             x = layer(x=x, mask=mask)
@@ -172,7 +172,7 @@ class Decoder(Module):
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])
 
-    def __call__(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
         # Run through each transformer layer
         for layer in self.layers:
             x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
@@ -194,7 +194,7 @@ class Generator(Module):
         super().__init__()
         self.projection = nn.Linear(d_model, n_vocab)
 
-    def __call__(self, x):
+    def forward(self, x):
         return self.projection(x)
 
 
@@ -219,7 +219,7 @@ class EncoderDecoder(Module):
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
 
-    def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
         # Run the source through encoder
         enc = self.encode(src, src_mask)
         # Run encodings and targets through decoder
diff --git a/labml_nn/transformers/positional_encoding.py b/labml_nn/transformers/positional_encoding.py
index 74c56f9f..f974b096 100644
--- a/labml_nn/transformers/positional_encoding.py
+++ b/labml_nn/transformers/positional_encoding.py
@@ -36,7 +36,7 @@ class PositionalEncoding(Module):
 
         self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len), False)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
         x = x + pe
         x = self.dropout(x)
diff --git a/labml_nn/transformers/switch/__init__.py b/labml_nn/transformers/switch/__init__.py
index 1fb31db8..119e81b3 100644
--- a/labml_nn/transformers/switch/__init__.py
+++ b/labml_nn/transformers/switch/__init__.py
@@ -80,7 +80,7 @@ class SwitchFeedForward(Module):
         self.switch = nn.Linear(d_model, n_experts)
         self.softmax = nn.Softmax(dim=-1)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         """
         * `x` is the input to the switching module with shape `[seq_len, batch_size, d_model]`
         """
@@ -189,7 +189,7 @@ class SwitchTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])
 
-    def __call__(self, *,
+    def forward(self, *,
                  x: torch.Tensor,
                  mask: torch.Tensor):
         # Normalize the vectors before doing self attention
@@ -221,7 +221,7 @@ class SwitchTransformer(Module):
         # Final normalization layer
         self.norm = nn.LayerNorm([layer.size])
 
-    def __call__(self, x: torch.Tensor, mask: torch.Tensor):
+    def forward(self, x: torch.Tensor, mask: torch.Tensor):
         # Run through each transformer layer
         counts, route_prob, n_dropped = [], [], []
         for layer in self.layers:
diff --git a/labml_nn/transformers/switch/experiment.py b/labml_nn/transformers/switch/experiment.py
index a426a1ec..6956ec8b 100644
--- a/labml_nn/transformers/switch/experiment.py
+++ b/labml_nn/transformers/switch/experiment.py
@@ -34,7 +34,7 @@ class AutoregressiveModel(Module):
         self.generator = nn.Linear(d_model, n_vocab)
         self.mask = None
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         # Initialize the subsequent mask
         if self.mask is None or self.mask.size(0) != len(x):
             from labml_nn.transformers.utils import subsequent_mask
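
For background on why every `__call__` above becomes `forward`: `torch.nn.Module.__call__` is the dispatcher that runs registered hooks and then delegates to `forward`, so overriding `forward` (as this patch does throughout) keeps the hook and tracing machinery intact. The sketch below is a minimal illustration of that behaviour; the `Scale` module and its hook are hypothetical and not part of the patch.

import torch
from torch import nn


class Scale(nn.Module):
    """Illustrative module that overrides forward(), as the patch does throughout."""

    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor):
        # nn.Module.__call__ dispatches to this method after running registered hooks
        return x * self.factor


scale = Scale(2.0)
# Forward hooks only fire when the module is invoked via __call__, which then calls forward()
scale.register_forward_hook(lambda module, inputs, output: print('hook saw', output.shape))
y = scale(torch.ones(3))  # prints: hook saw torch.Size([3])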