diff --git a/Makefile b/Makefile index 781da712..8848a160 100644 --- a/Makefile +++ b/Makefile @@ -38,8 +38,12 @@ docs-zh: ## Chinese Translation cd labml_nn; pylit --translate zh --translate_cache ../translate_cache --remove_empty_sections --title_md -t ../../../pylit/templates/nn -d ../docs/zh -w * docs: ## Render annotated HTML + mv docs/zh docs_zh + mv docs/si docs_si find ./docs/ -name "*.html" -type f -delete find ./docs/ -name "*.svg" -type f -delete + mv docs_si docs/si + mv docs_zh docs/zh python utils/sitemap.py python utils/diagrams.py cd labml_nn; pylit --remove_empty_sections --title_md -t ../../../pylit/templates/nn -d ../docs -w * diff --git a/docs/activations/fta/experiment.html b/docs/activations/fta/experiment.html index 398fd7c7..85316cb4 100644 --- a/docs/activations/fta/experiment.html +++ b/docs/activations/fta/experiment.html @@ -1,5 +1,5 @@ - + @@ -76,24 +76,24 @@ #

Fuzzy Tiling Activation Experiment

-

Open In Colab Open In Comet

+

Open In Colab

Here we train a transformer that uses Fuzzy Tiling Activation in the Feed-Forward Network. We use it for a language model and train it on Tiny Shakespeare dataset for demonstration.

However, this is probably not the ideal task for FTA, and we believe FTA is more suitable for modeling data with continuous variables.

-
22import copy
-23
-24import torch
-25import torch.nn as nn
-26
-27from labml import experiment
-28from labml.configs import option
-29from labml_helpers.module import Module
-30from labml_nn.activations.fta import FTA
-31from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
-32from labml_nn.transformers import MultiHeadAttention, TransformerLayer
-33from labml_nn.transformers.utils import subsequent_mask
+
21import copy
+22
+23import torch
+24import torch.nn as nn
+25
+26from labml import experiment
+27from labml.configs import option
+28from labml_helpers.module import Module
+29from labml_nn.activations.fta import FTA
+30from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+31from labml_nn.transformers import MultiHeadAttention, TransformerLayer
+32from labml_nn.transformers.utils import subsequent_mask
@@ -105,7 +105,7 @@
-
36class FeedForwardFTA(nn.Module):
+
35class FeedForwardFTA(nn.Module):
@@ -124,9 +124,9 @@
-
41    def __init__(self, d_model: int, d_ff: int,
-42                 activation: FTA,
-43                 dropout: float = 0.1):
+
40    def __init__(self, d_model: int, d_ff: int,
+41                 activation: FTA,
+42                 dropout: float = 0.1):
@@ -137,7 +137,7 @@
-
50        super().__init__()
+
49        super().__init__()
@@ -149,7 +149,7 @@
-
52        self.layer1 = nn.Linear(d_model, d_ff)
+
51        self.layer1 = nn.Linear(d_model, d_ff)
@@ -161,7 +161,7 @@
-
54        self.layer2 = nn.Linear(d_ff * activation.expansion_factor, d_model)
+
53        self.layer2 = nn.Linear(d_ff * activation.expansion_factor, d_model)
@@ -173,7 +173,7 @@
-
56        self.dropout = nn.Dropout(dropout)
+
55        self.dropout = nn.Dropout(dropout)
@@ -185,7 +185,7 @@
-
58        self.activation = activation
+
57        self.activation = activation
@@ -196,7 +196,7 @@
-
60    def forward(self, x: torch.Tensor):
+
59    def forward(self, x: torch.Tensor):
@@ -208,7 +208,7 @@
-
62        x = self.activation(self.layer1(x))
+
61        x = self.activation(self.layer1(x))
@@ -220,7 +220,7 @@
-
64        x = self.dropout(x)
+
63        x = self.dropout(x)
@@ -232,7 +232,7 @@
-
66        return self.layer2(x)
+
65        return self.layer2(x)
@@ -245,7 +245,7 @@
-
69class AutoregressiveTransformer(Module):
+
68class AutoregressiveTransformer(Module):
@@ -265,7 +265,7 @@
-
77    def __init__(self, n_tokens: int, d_model: int, n_layers: int, layer: TransformerLayer):
+
76    def __init__(self, n_tokens: int, d_model: int, n_layers: int, layer: TransformerLayer):
@@ -276,7 +276,7 @@
-
84        super().__init__()
+
83        super().__init__()
@@ -289,7 +289,7 @@
-
86        self.transformer_layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
+
85        self.transformer_layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
@@ -301,7 +301,7 @@
-
89        self.emb = nn.Embedding(n_tokens, d_model)
+
88        self.emb = nn.Embedding(n_tokens, d_model)
@@ -313,7 +313,7 @@
-
91        self.readout = nn.Linear(d_model, n_tokens)
+
90        self.readout = nn.Linear(d_model, n_tokens)
@@ -325,7 +325,7 @@
-
94        self.mask = None
+
93        self.mask = None
@@ -339,7 +339,7 @@
-
96    def forward(self, x: torch.Tensor):
+
95    def forward(self, x: torch.Tensor):
@@ -351,7 +351,7 @@
-
101        if self.mask is None or self.mask.size(0) != len(x):
+
100        if self.mask is None or self.mask.size(0) != len(x):
@@ -363,7 +363,7 @@
-
103            self.mask = subsequent_mask(len(x)).to(x.device)
+
102            self.mask = subsequent_mask(len(x)).to(x.device)
@@ -375,7 +375,7 @@
-
106        x = self.emb(x)
+
105        x = self.emb(x)
@@ -387,8 +387,8 @@
-
108        for layer in self.transformer_layers:
-109            x = layer(x=x, mask=self.mask)
+
107        for layer in self.transformer_layers:
+108            x = layer(x=x, mask=self.mask)
@@ -400,7 +400,7 @@
-
111        x = self.readout(x)
+
110        x = self.readout(x)
@@ -412,7 +412,7 @@
-
114        return x, None
+
113        return x, None
@@ -426,7 +426,7 @@
-
117class Configs(NLPAutoRegressionConfigs):
+
116class Configs(NLPAutoRegressionConfigs):
@@ -438,7 +438,7 @@
-
126    model: AutoregressiveTransformer
+
125    model: AutoregressiveTransformer
@@ -450,7 +450,7 @@
-
129    n_layers: int = 4
+
128    n_layers: int = 4
@@ -462,8 +462,8 @@
-
132    deep_norm_alpha: float
-133    deep_norm_beta: float
+
131    deep_norm_alpha: float
+132    deep_norm_beta: float
@@ -475,7 +475,7 @@
-
136    n_heads: int = 4
+
135    n_heads: int = 4
@@ -487,7 +487,7 @@
-
138    d_model: int = 256
+
137    d_model: int = 256
@@ -499,7 +499,7 @@
-
140    d_k: int = 16
+
139    d_k: int = 16
@@ -511,7 +511,7 @@
-
142    d_ff: int = 256
+
141    d_ff: int = 256
@@ -523,10 +523,10 @@
-
145    fta_lower_limit: float = -1.
-146    fta_upper_limit: float = +1.
-147    fta_delta: float = 0.2
-148    fta_eta: float = 0.05
+
144    fta_lower_limit: float = -1.
+145    fta_upper_limit: float = +1.
+146    fta_delta: float = 0.2
+147    fta_eta: float = 0.05
@@ -538,8 +538,8 @@
-
151@option(Configs.model)
-152def _model(c: Configs):
+
150@option(Configs.model)
+151def _model(c: Configs):
@@ -551,7 +551,7 @@
-
158    fta = FTA(c.fta_lower_limit, c.fta_upper_limit, c.fta_delta, c.fta_eta)
+
157    fta = FTA(c.fta_lower_limit, c.fta_upper_limit, c.fta_delta, c.fta_eta)
@@ -565,15 +565,15 @@
-
162    m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.n_layers,
-163                                  TransformerLayer(d_model=c.d_model,
-164                                                   feed_forward=FeedForwardFTA(d_model=c.d_model,
-165                                                                               d_ff=c.d_ff,
-166                                                                               activation=fta,
-167                                                                               dropout=0.1),
-168                                                   self_attn=MultiHeadAttention(c.n_heads, c.d_model,
-169                                                                                dropout_prob=0.0),
-170                                                   dropout_prob=0.0))
+
161    m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.n_layers,
+162                                  TransformerLayer(d_model=c.d_model,
+163                                                   feed_forward=FeedForwardFTA(d_model=c.d_model,
+164                                                                               d_ff=c.d_ff,
+165                                                                               activation=fta,
+166                                                                               dropout=0.1),
+167                                                   self_attn=MultiHeadAttention(c.n_heads, c.d_model,
+168                                                                                dropout_prob=0.0),
+169                                                   dropout_prob=0.0))
@@ -585,7 +585,7 @@
-
173    return m.to(c.device)
+
172    return m.to(c.device)
@@ -597,7 +597,7 @@
-
176def main():
+
175def main():
@@ -609,7 +609,7 @@
-
181    experiment.create(name="fta", writers={'screen',  'comet', 'labml'})
+
180    experiment.create(name="fta", writers={'screen', 'labml'})
@@ -621,7 +621,7 @@
-
183    conf = Configs()
+
182    conf = Configs()
@@ -633,7 +633,7 @@
-
185    experiment.configs(conf, {
+
184    experiment.configs(conf, {
@@ -645,7 +645,7 @@
-
187        'tokenizer': 'character',
+
186        'tokenizer': 'character',
@@ -657,7 +657,7 @@
-
189        'prompt_separator': '',
+
188        'prompt_separator': '',
@@ -669,7 +669,7 @@
-
191        'prompt': 'It is ',
+
190        'prompt': 'It is ',
@@ -681,7 +681,7 @@
-
193        'text': 'tiny_shakespeare',
+
192        'text': 'tiny_shakespeare',
@@ -693,7 +693,7 @@
-
196        'seq_len': 256,
+
195        'seq_len': 256,
@@ -705,7 +705,7 @@
-
198        'epochs': 32,
+
197        'epochs': 32,
@@ -717,7 +717,7 @@
-
200        'batch_size': 16,
+
199        'batch_size': 16,
@@ -729,7 +729,7 @@
-
202        'inner_iterations': 10,
+
201        'inner_iterations': 10,
@@ -741,9 +741,9 @@
-
205        'optimizer.optimizer': 'Adam',
-206        'optimizer.learning_rate': 3e-4,
-207    })
+
204        'optimizer.optimizer': 'Adam',
+205        'optimizer.learning_rate': 3e-4,
+206    })
@@ -755,7 +755,7 @@
-
210    experiment.add_pytorch_models({'model': conf.model})
+
209    experiment.add_pytorch_models({'model': conf.model})
@@ -767,7 +767,7 @@
-
213    with experiment.start():
+
212    with experiment.start():
@@ -779,7 +779,7 @@
-
215        conf.run()
+
214        conf.run()
@@ -791,8 +791,8 @@
-
219if __name__ == '__main__':
-220    main()
+
218if __name__ == '__main__':
+219    main()

Fuzzy Tiling Activations (FTA)

-

Open In Colab Open In Comet

+

Open In Colab

This is a PyTorch implementation/tutorial of Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online.

Fuzzy tiling activations are a form of sparse activations based on binning.

Binning is classification of a scalar value into a bin based on intervals. One problem with binning is that it gives zero gradients for most values (except at the boundary of bins). The other is that binning loses precision if the bin intervals are large.

@@ -99,8 +99,8 @@
-
62import torch
-63from torch import nn
+
61import torch
+62from torch import nn
@@ -112,7 +112,7 @@
-
66class FTA(nn.Module):
+
65class FTA(nn.Module):
@@ -131,7 +131,7 @@
-
71    def __init__(self, lower_limit: float, upper_limit: float, delta: float, eta: float):
+
70    def __init__(self, lower_limit: float, upper_limit: float, delta: float, eta: float):
@@ -142,7 +142,7 @@
-
78        super().__init__()
+
77        super().__init__()
@@ -154,7 +154,7 @@
-
81        self.c = nn.Parameter(torch.arange(lower_limit, upper_limit, delta), requires_grad=False)
+
80        self.c = nn.Parameter(torch.arange(lower_limit, upper_limit, delta), requires_grad=False)
@@ -166,7 +166,7 @@
-
83        self.expansion_factor = len(self.c)
+
82        self.expansion_factor = len(self.c)
@@ -178,7 +178,7 @@
-
85        self.delta = delta
+
84        self.delta = delta
@@ -190,7 +190,7 @@
-
87        self.eta = eta
+
86        self.eta = eta
@@ -203,7 +203,7 @@
-
89    def fuzzy_i_plus(self, x: torch.Tensor):
+
88    def fuzzy_i_plus(self, x: torch.Tensor):
@@ -214,7 +214,7 @@
-
95        return (x <= self.eta) * x + (x > self.eta)
+
94        return (x <= self.eta) * x + (x > self.eta)
@@ -225,7 +225,7 @@
-
97    def forward(self, z: torch.Tensor):
+
96    def forward(self, z: torch.Tensor):
@@ -237,7 +237,7 @@
-
100        z = z.view(*z.shape, 1)
+
99        z = z.view(*z.shape, 1)
@@ -249,7 +249,7 @@
-
103        z = 1. - self.fuzzy_i_plus(torch.clip(self.c - z, min=0.) + torch.clip(z - self.delta - self.c, min=0.))
+
102        z = 1. - self.fuzzy_i_plus(torch.clip(self.c - z, min=0.) + torch.clip(z - self.delta - self.c, min=0.))
@@ -261,7 +261,7 @@
-
107        return z.view(*z.shape[:-2], -1)
+
106        return z.view(*z.shape[:-2], -1)
@@ -273,7 +273,7 @@
-
110def _test():
+
109def _test():
@@ -284,7 +284,7 @@
-
114    from labml.logger import inspect
+
113    from labml.logger import inspect
@@ -296,7 +296,7 @@
-
117    a = FTA(-10, 10, 2., 0.5)
+
116    a = FTA(-10, 10, 2., 0.5)
@@ -308,7 +308,7 @@
-
119    inspect(a.c)
+
118    inspect(a.c)
@@ -320,7 +320,7 @@
-
121    inspect(a.expansion_factor)
+
120    inspect(a.expansion_factor)
@@ -332,7 +332,7 @@
-
124    z = torch.tensor([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9., 10., 11.])
+
123    z = torch.tensor([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9., 10., 11.])
@@ -344,7 +344,7 @@
-
126    inspect(z)
+
125    inspect(z)
@@ -356,11 +356,11 @@
-
128    inspect(a(z))
+            
127    inspect(a(z))
+128
 129
-130
-131if __name__ == '__main__':
-132    _test()
+130if __name__ == '__main__': +131 _test()

Denoising Diffusion Probabilistic Models (DDPM) training

-

Open In Colab Open In Comet

+

Open In Colab

This trains a DDPM based model on CelebA HQ dataset. You can find the download instruction in this discussion on fast.ai. Save the images inside data/celebA folder.

The paper had used a exponential moving average of the model with a decay of . We have skipped this for simplicity.

-
21from typing import List
-22
-23import torch
-24import torch.utils.data
-25import torchvision
-26from PIL import Image
-27
-28from labml import lab, tracker, experiment, monit
-29from labml.configs import BaseConfigs, option
-30from labml_helpers.device import DeviceConfigs
-31from labml_nn.diffusion.ddpm import DenoiseDiffusion
-32from labml_nn.diffusion.ddpm.unet import UNet
+
20from typing import List
+21
+22import torch
+23import torch.utils.data
+24import torchvision
+25from PIL import Image
+26
+27from labml import lab, tracker, experiment, monit
+28from labml.configs import BaseConfigs, option
+29from labml_helpers.device import DeviceConfigs
+30from labml_nn.diffusion.ddpm import DenoiseDiffusion
+31from labml_nn.diffusion.ddpm.unet import UNet
@@ -106,7 +106,7 @@
-
35class Configs(BaseConfigs):
+
34class Configs(BaseConfigs):
@@ -119,7 +119,7 @@
-
42    device: torch.device = DeviceConfigs()
+
41    device: torch.device = DeviceConfigs()
@@ -131,7 +131,7 @@
-
45    eps_model: UNet
+
44    eps_model: UNet
@@ -143,7 +143,7 @@
-
47    diffusion: DenoiseDiffusion
+
46    diffusion: DenoiseDiffusion
@@ -155,7 +155,7 @@
-
50    image_channels: int = 3
+
49    image_channels: int = 3
@@ -167,7 +167,7 @@
-
52    image_size: int = 32
+
51    image_size: int = 32
@@ -179,7 +179,7 @@
-
54    n_channels: int = 64
+
53    n_channels: int = 64
@@ -192,7 +192,7 @@
-
57    channel_multipliers: List[int] = [1, 2, 2, 4]
+
56    channel_multipliers: List[int] = [1, 2, 2, 4]
@@ -204,7 +204,7 @@
-
59    is_attention: List[int] = [False, False, False, True]
+
58    is_attention: List[int] = [False, False, False, True]
@@ -216,7 +216,7 @@
-
62    n_steps: int = 1_000
+
61    n_steps: int = 1_000
@@ -228,7 +228,7 @@
-
64    batch_size: int = 64
+
63    batch_size: int = 64
@@ -240,7 +240,7 @@
-
66    n_samples: int = 16
+
65    n_samples: int = 16
@@ -252,7 +252,7 @@
-
68    learning_rate: float = 2e-5
+
67    learning_rate: float = 2e-5
@@ -264,7 +264,7 @@
-
71    epochs: int = 1_000
+
70    epochs: int = 1_000
@@ -276,7 +276,7 @@
-
74    dataset: torch.utils.data.Dataset
+
73    dataset: torch.utils.data.Dataset
@@ -288,7 +288,7 @@
-
76    data_loader: torch.utils.data.DataLoader
+
75    data_loader: torch.utils.data.DataLoader
@@ -300,7 +300,7 @@
-
79    optimizer: torch.optim.Adam
+
78    optimizer: torch.optim.Adam
@@ -311,7 +311,7 @@
-
81    def init(self):
+
80    def init(self):
@@ -323,12 +323,12 @@
-
83        self.eps_model = UNet(
-84            image_channels=self.image_channels,
-85            n_channels=self.n_channels,
-86            ch_mults=self.channel_multipliers,
-87            is_attn=self.is_attention,
-88        ).to(self.device)
+
82        self.eps_model = UNet(
+83            image_channels=self.image_channels,
+84            n_channels=self.n_channels,
+85            ch_mults=self.channel_multipliers,
+86            is_attn=self.is_attention,
+87        ).to(self.device)
@@ -340,11 +340,11 @@
-
91        self.diffusion = DenoiseDiffusion(
-92            eps_model=self.eps_model,
-93            n_steps=self.n_steps,
-94            device=self.device,
-95        )
+
90        self.diffusion = DenoiseDiffusion(
+91            eps_model=self.eps_model,
+92            n_steps=self.n_steps,
+93            device=self.device,
+94        )
@@ -356,7 +356,7 @@
-
98        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
+
97        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
@@ -368,7 +368,7 @@
-
100        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)
+
99        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)
@@ -380,7 +380,7 @@
-
103        tracker.set_image("sample", True)
+
102        tracker.set_image("sample", True)
@@ -392,7 +392,7 @@
-
105    def sample(self):
+
104    def sample(self):
@@ -403,7 +403,7 @@
-
109        with torch.no_grad():
+
108        with torch.no_grad():
@@ -415,8 +415,8 @@
-
111            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
-112                            device=self.device)
+
110            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
+111                            device=self.device)
@@ -428,7 +428,7 @@
-
115            for t_ in monit.iterate('Sample', self.n_steps):
+
114            for t_ in monit.iterate('Sample', self.n_steps):
@@ -440,7 +440,7 @@
-
117                t = self.n_steps - t_ - 1
+
116                t = self.n_steps - t_ - 1
@@ -452,7 +452,7 @@
-
119                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
+
118                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
@@ -464,7 +464,7 @@
-
122            tracker.save('sample', x)
+
121            tracker.save('sample', x)
@@ -476,7 +476,7 @@
-
124    def train(self):
+
123    def train(self):
@@ -488,7 +488,7 @@
-
130        for data in monit.iterate('Train', self.data_loader):
+
129        for data in monit.iterate('Train', self.data_loader):
@@ -500,7 +500,7 @@
-
132            tracker.add_global_step()
+
131            tracker.add_global_step()
@@ -512,7 +512,7 @@
-
134            data = data.to(self.device)
+
133            data = data.to(self.device)
@@ -524,7 +524,7 @@
-
137            self.optimizer.zero_grad()
+
136            self.optimizer.zero_grad()
@@ -536,7 +536,7 @@
-
139            loss = self.diffusion.loss(data)
+
138            loss = self.diffusion.loss(data)
@@ -548,7 +548,7 @@
-
141            loss.backward()
+
140            loss.backward()
@@ -560,7 +560,7 @@
-
143            self.optimizer.step()
+
142            self.optimizer.step()
@@ -572,7 +572,7 @@
-
145            tracker.save('loss', loss)
+
144            tracker.save('loss', loss)
@@ -584,7 +584,7 @@
-
147    def run(self):
+
146    def run(self):
@@ -595,7 +595,7 @@
-
151        for _ in monit.loop(self.epochs):
+
150        for _ in monit.loop(self.epochs):
@@ -607,7 +607,7 @@
-
153            self.train()
+
152            self.train()
@@ -619,7 +619,7 @@
-
155            self.sample()
+
154            self.sample()
@@ -631,7 +631,7 @@
-
157            tracker.new_line()
+
156            tracker.new_line()
@@ -643,7 +643,7 @@
-
159            experiment.save_checkpoint()
+
158            experiment.save_checkpoint()
@@ -655,7 +655,7 @@
-
162class CelebADataset(torch.utils.data.Dataset):
+
161class CelebADataset(torch.utils.data.Dataset):
@@ -666,8 +666,8 @@
-
167    def __init__(self, image_size: int):
-168        super().__init__()
+
166    def __init__(self, image_size: int):
+167        super().__init__()
@@ -679,7 +679,7 @@
-
171        folder = lab.get_data_path() / 'celebA'
+
170        folder = lab.get_data_path() / 'celebA'
@@ -691,7 +691,7 @@
-
173        self._files = [p for p in folder.glob(f'**/*.jpg')]
+
172        self._files = [p for p in folder.glob(f'**/*.jpg')]
@@ -703,10 +703,10 @@
-
176        self._transform = torchvision.transforms.Compose([
-177            torchvision.transforms.Resize(image_size),
-178            torchvision.transforms.ToTensor(),
-179        ])
+
175        self._transform = torchvision.transforms.Compose([
+176            torchvision.transforms.Resize(image_size),
+177            torchvision.transforms.ToTensor(),
+178        ])
@@ -718,7 +718,7 @@
-
181    def __len__(self):
+
180    def __len__(self):
@@ -729,7 +729,7 @@
-
185        return len(self._files)
+
184        return len(self._files)
@@ -741,7 +741,7 @@
-
187    def __getitem__(self, index: int):
+
186    def __getitem__(self, index: int):
@@ -752,8 +752,8 @@
-
191        img = Image.open(self._files[index])
-192        return self._transform(img)
+
190        img = Image.open(self._files[index])
+191        return self._transform(img)
@@ -765,8 +765,8 @@
-
195@option(Configs.dataset, 'CelebA')
-196def celeb_dataset(c: Configs):
+
194@option(Configs.dataset, 'CelebA')
+195def celeb_dataset(c: Configs):
@@ -777,7 +777,7 @@
-
200    return CelebADataset(c.image_size)
+
199    return CelebADataset(c.image_size)
@@ -789,7 +789,7 @@
-
203class MNISTDataset(torchvision.datasets.MNIST):
+
202class MNISTDataset(torchvision.datasets.MNIST):
@@ -800,13 +800,13 @@
-
208    def __init__(self, image_size):
-209        transform = torchvision.transforms.Compose([
-210            torchvision.transforms.Resize(image_size),
-211            torchvision.transforms.ToTensor(),
-212        ])
-213
-214        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)
+
207    def __init__(self, image_size):
+208        transform = torchvision.transforms.Compose([
+209            torchvision.transforms.Resize(image_size),
+210            torchvision.transforms.ToTensor(),
+211        ])
+212
+213        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)
@@ -817,8 +817,8 @@
-
216    def __getitem__(self, item):
-217        return super().__getitem__(item)[0]
+
215    def __getitem__(self, item):
+216        return super().__getitem__(item)[0]
@@ -830,8 +830,8 @@
-
220@option(Configs.dataset, 'MNIST')
-221def mnist_dataset(c: Configs):
+
219@option(Configs.dataset, 'MNIST')
+220def mnist_dataset(c: Configs):
@@ -842,7 +842,7 @@
-
225    return MNISTDataset(c.image_size)
+
224    return MNISTDataset(c.image_size)
@@ -853,7 +853,7 @@
-
228def main():
+
227def main():
@@ -865,7 +865,7 @@
-
230    experiment.create(name='diffuse', writers={'screen', 'comet'})
+
229    experiment.create(name='diffuse', writers={'screen', 'labml'})
@@ -877,7 +877,7 @@
-
233    configs = Configs()
+
232    configs = Configs()
@@ -889,11 +889,11 @@
-
236    experiment.configs(configs, {
-237        'dataset': 'CelebA',  # 'MNIST'
-238        'image_channels': 3,  # 1,
-239        'epochs': 100,  # 5,
-240    })
+
235    experiment.configs(configs, {
+236        'dataset': 'CelebA',  # 'MNIST'
+237        'image_channels': 3,  # 1,
+238        'epochs': 100,  # 5,
+239    })
@@ -905,7 +905,7 @@
-
243    configs.init()
+
242    configs.init()
@@ -917,7 +917,7 @@
-
246    experiment.add_pytorch_models({'eps_model': configs.eps_model})
+
245    experiment.add_pytorch_models({'eps_model': configs.eps_model})
@@ -929,8 +929,8 @@
-
249    with experiment.start():
-250        configs.run()
+
248    with experiment.start():
+249        configs.run()
@@ -942,8 +942,8 @@
-
254if __name__ == '__main__':
-255    main()
+
253if __name__ == '__main__':
+254    main()

Denoising Diffusion Probabilistic Models (DDPM)

-

Open In Colab Open In Comet

+

Open In Colab

This is a PyTorch implementation/tutorial of the paper Denoising Diffusion Probabilistic Models.

In simple terms, we get an image from data and add noise step by step. Then We train a model to predict that noise at each step and use the model to generate images.

The following definitions and derivations show how this works. For details please refer to the paper.

@@ -278,7 +278,7 @@ s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7 c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z M834 80h400000v40h-400000z">​ϵ,t)∥∥​2]​

That is, we are training to predict the noise.

Simplified loss

-

163from typing import Tuple, Optional
-164
-165import torch
-166import torch.nn.functional as F
-167import torch.utils.data
-168from torch import nn
-169
-170from labml_nn.diffusion.ddpm.utils import gather
+
162from typing import Tuple, Optional
+163
+164import torch
+165import torch.nn.functional as F
+166import torch.utils.data
+167from torch import nn
+168
+169from labml_nn.diffusion.ddpm.utils import gather
@@ -326,7 +326,7 @@ M834 80h400000v40h-400000z">
173class DenoiseDiffusion:
+
172class DenoiseDiffusion:
@@ -343,7 +343,7 @@ M834 80h400000v40h-400000z">
178    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
+
177    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
@@ -354,8 +354,8 @@ M834 80h400000v40h-400000z">
184        super().__init__()
-185        self.eps_model = eps_model
+
183        super().__init__()
+184        self.eps_model = eps_model
@@ -367,7 +367,7 @@ M834 80h400000v40h-400000z">
188        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
+
187        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
@@ -379,7 +379,7 @@ M834 80h400000v40h-400000z">
191        self.alpha = 1. - self.beta
+
190        self.alpha = 1. - self.beta
@@ -391,7 +391,7 @@ M834 80h400000v40h-400000z">
193        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
+
192        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
@@ -403,7 +403,7 @@ M834 80h400000v40h-400000z">
195        self.n_steps = n_steps
+
194        self.n_steps = n_steps
@@ -415,7 +415,7 @@ M834 80h400000v40h-400000z">
197        self.sigma2 = self.beta
+
196        self.sigma2 = self.beta
@@ -438,7 +438,7 @@ c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z M834 80h400000v40h-400000z">​x0​,(1−αt​ˉ​)I)​
-
199    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+
198    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -461,7 +461,7 @@ M834 80h400000v40h-400000z">
209        mean = gather(self.alpha_bar, t) ** 0.5 * x0
+
208        mean = gather(self.alpha_bar, t) ** 0.5 * x0
@@ -473,7 +473,7 @@ M834 80h400000v40h-400000z">
211        var = 1 - gather(self.alpha_bar, t)
+
210        var = 1 - gather(self.alpha_bar, t)
@@ -485,7 +485,7 @@ M834 80h400000v40h-400000z">
213        return mean, var
+
212        return mean, var
@@ -508,7 +508,7 @@ c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z M834 80h400000v40h-400000z">​x0​,(1−αt​ˉ​)I)​
-
215    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
+
214    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
@@ -520,8 +520,8 @@ M834 80h400000v40h-400000z">
225        if eps is None:
-226            eps = torch.randn_like(x0)
+
224        if eps is None:
+225            eps = torch.randn_like(x0)
@@ -533,7 +533,7 @@ M834 80h400000v40h-400000z">
229        mean, var = self.q_xt_x0(x0, t)
+
228        mean, var = self.q_xt_x0(x0, t)
@@ -545,7 +545,7 @@ M834 80h400000v40h-400000z">
231        return mean + (var ** 0.5) * eps
+
230        return mean + (var ** 0.5) * eps
@@ -579,7 +579,7 @@ c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z M834 80h400000v40h-400000z">​βt​​ϵθ​(xt​,t))​
-
233    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
+
232    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
@@ -591,7 +591,7 @@ M834 80h400000v40h-400000z">
247        eps_theta = self.eps_model(xt, t)
+
246        eps_theta = self.eps_model(xt, t)
@@ -603,7 +603,7 @@ M834 80h400000v40h-400000z">
249        alpha_bar = gather(self.alpha_bar, t)
+
248        alpha_bar = gather(self.alpha_bar, t)
@@ -615,7 +615,7 @@ M834 80h400000v40h-400000z">
251        alpha = gather(self.alpha, t)
+
250        alpha = gather(self.alpha, t)
@@ -638,7 +638,7 @@ M834 80h400000v40h-400000z">
253        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
+
252        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
@@ -672,7 +672,7 @@ M834 80h400000v40h-400000z">
256        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
+
255        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
@@ -684,7 +684,7 @@ M834 80h400000v40h-400000z">
258        var = gather(self.sigma2, t)
+
257        var = gather(self.sigma2, t)
@@ -696,7 +696,7 @@ M834 80h400000v40h-400000z">
261        eps = torch.randn(xt.shape, device=xt.device)
+
260        eps = torch.randn(xt.shape, device=xt.device)
@@ -708,7 +708,7 @@ M834 80h400000v40h-400000z">
263        return mean + (var ** .5) * eps
+
262        return mean + (var ** .5) * eps
@@ -717,7 +717,7 @@ M834 80h400000v40h-400000z">
+
264    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
@@ -755,7 +755,7 @@ M834 80h400000v40h-400000z">
274        batch_size = x0.shape[0]
+
273        batch_size = x0.shape[0]
@@ -767,7 +767,7 @@ M834 80h400000v40h-400000z">
276        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
+
275        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
@@ -779,8 +779,8 @@ M834 80h400000v40h-400000z">
279        if noise is None:
-280            noise = torch.randn_like(x0)
+
278        if noise is None:
+279            noise = torch.randn_like(x0)
@@ -792,7 +792,7 @@ M834 80h400000v40h-400000z">
283        xt = self.q_sample(x0, t, eps=noise)
+
282        xt = self.q_sample(x0, t, eps=noise)
@@ -826,7 +826,7 @@ M834 80h400000v40h-400000z">
285        eps_theta = self.eps_model(xt, t)
+
284        eps_theta = self.eps_model(xt, t)
@@ -838,7 +838,7 @@ M834 80h400000v40h-400000z">
288        return F.mse_loss(noise, eps_theta)
+
287        return F.mse_loss(noise, eps_theta)

Denoising Diffusion Probabilistic Models (DDPM)

-

Open In Colab Open In Comet

+

Open In Colab

This is a PyTorch implementation/tutorial of the paper Denoising Diffusion Probabilistic Models.

In simple terms, we get an image from data and add noise step by step. Then We train a model to predict that noise at each step and use the model to generate images.

Here is the UNet model that predicts the noise and training code. This file can generate samples and interpolations from a trained model.

diff --git a/docs/diffusion/ddpm/unet.html b/docs/diffusion/ddpm/unet.html index 4eabaf42..1c2109ce 100644 --- a/docs/diffusion/ddpm/unet.html +++ b/docs/diffusion/ddpm/unet.html @@ -1,5 +1,5 @@ - + diff --git a/docs/diffusion/ddpm/utils.html b/docs/diffusion/ddpm/utils.html index 95995e82..a1e25c23 100644 --- a/docs/diffusion/ddpm/utils.html +++ b/docs/diffusion/ddpm/utils.html @@ -1,5 +1,5 @@ - + diff --git a/docs/diffusion/index.html b/docs/diffusion/index.html index a9e9ba1c..cef3d8d4 100644 --- a/docs/diffusion/index.html +++ b/docs/diffusion/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/distillation/index.html b/docs/distillation/index.html index d6f4c942..506d6787 100644 --- a/docs/distillation/index.html +++ b/docs/distillation/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/distillation/large.html b/docs/distillation/large.html index 840dbe3f..bca72b2d 100644 --- a/docs/distillation/large.html +++ b/docs/distillation/large.html @@ -1,5 +1,5 @@ - + diff --git a/docs/distillation/readme.html b/docs/distillation/readme.html index 74c52600..f296823f 100644 --- a/docs/distillation/readme.html +++ b/docs/distillation/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/distillation/small.html b/docs/distillation/small.html index dbcafdd6..ab221183 100644 --- a/docs/distillation/small.html +++ b/docs/distillation/small.html @@ -1,5 +1,5 @@ - + diff --git a/docs/experiments/arithmetic_dataset.html b/docs/experiments/arithmetic_dataset.html index 2b6880f4..b4af9a75 100644 --- a/docs/experiments/arithmetic_dataset.html +++ b/docs/experiments/arithmetic_dataset.html @@ -1,5 +1,5 @@ - + diff --git a/docs/experiments/cifar10.html b/docs/experiments/cifar10.html index 23853103..ee48359e 100644 --- a/docs/experiments/cifar10.html +++ b/docs/experiments/cifar10.html @@ -1,5 +1,5 @@ - + diff --git a/docs/experiments/index.html b/docs/experiments/index.html index bf1253f4..29b75e7a 100644 --- a/docs/experiments/index.html +++ b/docs/experiments/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/experiments/mnist.html b/docs/experiments/mnist.html index 5a50610c..01c5c7f0 100644 --- a/docs/experiments/mnist.html +++ b/docs/experiments/mnist.html @@ -1,5 +1,5 @@ - + diff --git a/docs/experiments/nlp_autoregression.html b/docs/experiments/nlp_autoregression.html index f1ee53d3..c3d46a70 100644 --- a/docs/experiments/nlp_autoregression.html +++ b/docs/experiments/nlp_autoregression.html @@ -1,5 +1,5 @@ - + diff --git a/docs/experiments/nlp_classification.html b/docs/experiments/nlp_classification.html index 6fde095c..a57bb5ee 100644 --- a/docs/experiments/nlp_classification.html +++ b/docs/experiments/nlp_classification.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/cycle_gan/index.html b/docs/gan/cycle_gan/index.html index 48b7deab..9bf9c7bc 100644 --- a/docs/gan/cycle_gan/index.html +++ b/docs/gan/cycle_gan/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/cycle_gan/readme.html b/docs/gan/cycle_gan/readme.html index 2aee81a7..811521cc 100644 --- a/docs/gan/cycle_gan/readme.html +++ b/docs/gan/cycle_gan/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/dcgan/index.html b/docs/gan/dcgan/index.html index dc5492c3..48857c71 100644 --- a/docs/gan/dcgan/index.html +++ b/docs/gan/dcgan/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/dcgan/readme.html b/docs/gan/dcgan/readme.html index 394b9448..91207ece 100644 --- a/docs/gan/dcgan/readme.html +++ b/docs/gan/dcgan/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/index.html b/docs/gan/index.html index 14bfcfbe..4983e0c8 100644 --- a/docs/gan/index.html +++ b/docs/gan/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/original/experiment.html b/docs/gan/original/experiment.html index 5629f146..03b2a853 100644 --- a/docs/gan/original/experiment.html +++ b/docs/gan/original/experiment.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/original/index.html b/docs/gan/original/index.html index 0352128b..1fe45cb9 100644 --- a/docs/gan/original/index.html +++ b/docs/gan/original/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/original/readme.html b/docs/gan/original/readme.html index 6ca6cf47..2fb2c9b7 100644 --- a/docs/gan/original/readme.html +++ b/docs/gan/original/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/stylegan/experiment.html b/docs/gan/stylegan/experiment.html index a0dd4f62..16fcfb4e 100644 --- a/docs/gan/stylegan/experiment.html +++ b/docs/gan/stylegan/experiment.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/stylegan/index.html b/docs/gan/stylegan/index.html index eecbe12d..f44a19d1 100644 --- a/docs/gan/stylegan/index.html +++ b/docs/gan/stylegan/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/stylegan/readme.html b/docs/gan/stylegan/readme.html index 8a917138..37dd4d53 100644 --- a/docs/gan/stylegan/readme.html +++ b/docs/gan/stylegan/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/wasserstein/experiment.html b/docs/gan/wasserstein/experiment.html index b71c631f..482867e2 100644 --- a/docs/gan/wasserstein/experiment.html +++ b/docs/gan/wasserstein/experiment.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/wasserstein/gradient_penalty/experiment.html b/docs/gan/wasserstein/gradient_penalty/experiment.html index 9b544807..16e2889b 100644 --- a/docs/gan/wasserstein/gradient_penalty/experiment.html +++ b/docs/gan/wasserstein/gradient_penalty/experiment.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/wasserstein/gradient_penalty/index.html b/docs/gan/wasserstein/gradient_penalty/index.html index ecf739dc..432392e7 100644 --- a/docs/gan/wasserstein/gradient_penalty/index.html +++ b/docs/gan/wasserstein/gradient_penalty/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/wasserstein/gradient_penalty/readme.html b/docs/gan/wasserstein/gradient_penalty/readme.html index d1b446a4..0447fbc6 100644 --- a/docs/gan/wasserstein/gradient_penalty/readme.html +++ b/docs/gan/wasserstein/gradient_penalty/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/wasserstein/index.html b/docs/gan/wasserstein/index.html index 5735c7c5..2baf8219 100644 --- a/docs/gan/wasserstein/index.html +++ b/docs/gan/wasserstein/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/gan/wasserstein/readme.html b/docs/gan/wasserstein/readme.html index 88b99582..cbb6d5e2 100644 --- a/docs/gan/wasserstein/readme.html +++ b/docs/gan/wasserstein/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/graphs/gat/experiment.html b/docs/graphs/gat/experiment.html index 0eb6929e..eb5eb6c7 100644 --- a/docs/graphs/gat/experiment.html +++ b/docs/graphs/gat/experiment.html @@ -1,5 +1,5 @@ - + diff --git a/docs/graphs/gat/index.html b/docs/graphs/gat/index.html index 05600b81..2a66954c 100644 --- a/docs/graphs/gat/index.html +++ b/docs/graphs/gat/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/graphs/gat/readme.html b/docs/graphs/gat/readme.html index c7ac59e1..9d7662fd 100644 --- a/docs/graphs/gat/readme.html +++ b/docs/graphs/gat/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/graphs/gatv2/experiment.html b/docs/graphs/gatv2/experiment.html index 4877a782..b7e89d4a 100644 --- a/docs/graphs/gatv2/experiment.html +++ b/docs/graphs/gatv2/experiment.html @@ -1,5 +1,5 @@ - + diff --git a/docs/graphs/gatv2/index.html b/docs/graphs/gatv2/index.html index 8e596a02..06e7693e 100644 --- a/docs/graphs/gatv2/index.html +++ b/docs/graphs/gatv2/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/graphs/gatv2/readme.html b/docs/graphs/gatv2/readme.html index 46e307be..220179db 100644 --- a/docs/graphs/gatv2/readme.html +++ b/docs/graphs/gatv2/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/graphs/index.html b/docs/graphs/index.html index b1674934..05248b04 100644 --- a/docs/graphs/index.html +++ b/docs/graphs/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/hypernetworks/experiment.html b/docs/hypernetworks/experiment.html index 78d5273e..9b3c1f52 100644 --- a/docs/hypernetworks/experiment.html +++ b/docs/hypernetworks/experiment.html @@ -1,5 +1,5 @@ - + diff --git a/docs/hypernetworks/hyper_lstm.html b/docs/hypernetworks/hyper_lstm.html index 0a55f7ca..4cde7abb 100644 --- a/docs/hypernetworks/hyper_lstm.html +++ b/docs/hypernetworks/hyper_lstm.html @@ -1,5 +1,5 @@ - + diff --git a/docs/hypernetworks/index.html b/docs/hypernetworks/index.html index cfb06468..009263d1 100644 --- a/docs/hypernetworks/index.html +++ b/docs/hypernetworks/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/index.html b/docs/index.html index 677273bb..1f06b6d0 100644 --- a/docs/index.html +++ b/docs/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/lstm/index.html b/docs/lstm/index.html index 9dd24040..8225d9d6 100644 --- a/docs/lstm/index.html +++ b/docs/lstm/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/checkpoint.html b/docs/neox/checkpoint.html index ebfc64b4..d471dffa 100644 --- a/docs/neox/checkpoint.html +++ b/docs/neox/checkpoint.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/evaluation/half_precision.html b/docs/neox/evaluation/half_precision.html index 0df43082..0d9774da 100644 --- a/docs/neox/evaluation/half_precision.html +++ b/docs/neox/evaluation/half_precision.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/evaluation/index.html b/docs/neox/evaluation/index.html index 1cac38ce..2c61e919 100644 --- a/docs/neox/evaluation/index.html +++ b/docs/neox/evaluation/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/evaluation/llm_int8.html b/docs/neox/evaluation/llm_int8.html index b12f2a83..cff95374 100644 --- a/docs/neox/evaluation/llm_int8.html +++ b/docs/neox/evaluation/llm_int8.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/index.html b/docs/neox/index.html index 43f79328..f7922ddf 100644 --- a/docs/neox/index.html +++ b/docs/neox/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/model.html b/docs/neox/model.html index b0906adc..9092a600 100644 --- a/docs/neox/model.html +++ b/docs/neox/model.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/readme.html b/docs/neox/readme.html index 149419b7..91612c7f 100644 --- a/docs/neox/readme.html +++ b/docs/neox/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/samples/finetune.html b/docs/neox/samples/finetune.html index fb5d7435..81def7f7 100644 --- a/docs/neox/samples/finetune.html +++ b/docs/neox/samples/finetune.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/samples/generate.html b/docs/neox/samples/generate.html index 8228b9a2..5e84f046 100644 --- a/docs/neox/samples/generate.html +++ b/docs/neox/samples/generate.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/samples/index.html b/docs/neox/samples/index.html index cac8a13e..93168d6f 100644 --- a/docs/neox/samples/index.html +++ b/docs/neox/samples/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/samples/llm_int8.html b/docs/neox/samples/llm_int8.html index b3d024dd..85111093 100644 --- a/docs/neox/samples/llm_int8.html +++ b/docs/neox/samples/llm_int8.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/tokenizer.html b/docs/neox/tokenizer.html index 420b22a2..0308d2fc 100644 --- a/docs/neox/tokenizer.html +++ b/docs/neox/tokenizer.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/utils/cache.html b/docs/neox/utils/cache.html index aad9ee87..96711f8a 100644 --- a/docs/neox/utils/cache.html +++ b/docs/neox/utils/cache.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/utils/finetune.html b/docs/neox/utils/finetune.html index e11a8813..86370aac 100644 --- a/docs/neox/utils/finetune.html +++ b/docs/neox/utils/finetune.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/utils/index.html b/docs/neox/utils/index.html index 0806cc80..7e144c03 100644 --- a/docs/neox/utils/index.html +++ b/docs/neox/utils/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/utils/llm_int8.html b/docs/neox/utils/llm_int8.html index a4e12c6c..8588fdb4 100644 --- a/docs/neox/utils/llm_int8.html +++ b/docs/neox/utils/llm_int8.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/utils/text_dataset.html b/docs/neox/utils/text_dataset.html index c1b493a4..09479080 100644 --- a/docs/neox/utils/text_dataset.html +++ b/docs/neox/utils/text_dataset.html @@ -1,5 +1,5 @@ - + diff --git a/docs/neox/utils/trainer.html b/docs/neox/utils/trainer.html index 37a53cbc..d43bce74 100644 --- a/docs/neox/utils/trainer.html +++ b/docs/neox/utils/trainer.html @@ -1,5 +1,5 @@ - + diff --git a/docs/normalization/batch_channel_norm/index.html b/docs/normalization/batch_channel_norm/index.html index dbab9406..1f1799c4 100644 --- a/docs/normalization/batch_channel_norm/index.html +++ b/docs/normalization/batch_channel_norm/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/normalization/batch_norm/cifar10.html b/docs/normalization/batch_norm/cifar10.html index 211ef380..fbadc07e 100644 --- a/docs/normalization/batch_norm/cifar10.html +++ b/docs/normalization/batch_norm/cifar10.html @@ -1,5 +1,5 @@ - + diff --git a/docs/normalization/batch_norm/index.html b/docs/normalization/batch_norm/index.html index be959d8c..0d52615d 100644 --- a/docs/normalization/batch_norm/index.html +++ b/docs/normalization/batch_norm/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/normalization/batch_norm/mnist.html b/docs/normalization/batch_norm/mnist.html index 7cb1b069..3ea3eb98 100644 --- a/docs/normalization/batch_norm/mnist.html +++ b/docs/normalization/batch_norm/mnist.html @@ -1,5 +1,5 @@ - + diff --git a/docs/normalization/batch_norm/readme.html b/docs/normalization/batch_norm/readme.html index 13394ada..706e623e 100644 --- a/docs/normalization/batch_norm/readme.html +++ b/docs/normalization/batch_norm/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/normalization/deep_norm/experiment.html b/docs/normalization/deep_norm/experiment.html index f2fe3b91..e1ba372a 100644 --- a/docs/normalization/deep_norm/experiment.html +++ b/docs/normalization/deep_norm/experiment.html @@ -1,5 +1,5 @@ - + diff --git a/docs/normalization/deep_norm/index.html b/docs/normalization/deep_norm/index.html index 498c4919..242cedb3 100644 --- a/docs/normalization/deep_norm/index.html +++ b/docs/normalization/deep_norm/index.html @@ -1,5 +1,5 @@ - + @@ -76,7 +76,7 @@ #

DeepNorm

-

Open In Colab Open In Comet

+

Open In Colab

This is a PyTorch implementation of the DeepNorm from the paper DeepNet: Scaling Transformers to 1,000 Layers.

The paper proposes a method to stabilize extremely deep transformers through a new normalizing function to replace LayerNorm and a weight initialization scheme. This combines the performance of Post-LayerNorm and the stability of Pre-LayerNorm. Transformers with DeepNorms are supposed to be stable even without a learning rate warm-up.

The paper first shows that the changes to layer outputs (for the same input) change gradually during stable training; when unstable it changes rapidly during the initial training steps. This happens with initializing weights to small values, and learning rate warm-ups where the training is stable. They use the idea of keeping the changes to layer outputs small to derive the new normalization and weight initialization mechanism.

@@ -95,15 +95,15 @@
-
74from typing import Union, List
-75
-76import torch
-77from torch import nn, Size
-78
-79from labml_nn.normalization.layer_norm import LayerNorm
-80from labml_nn.transformers import MultiHeadAttention
-81from labml_nn.transformers.feed_forward import FeedForward
-82from labml_nn.transformers.utils import subsequent_mask
+
73from typing import Union, List
+74
+75import torch
+76from torch import nn, Size
+77
+78from labml_nn.normalization.layer_norm import LayerNorm
+79from labml_nn.transformers import MultiHeadAttention
+80from labml_nn.transformers.feed_forward import FeedForward
+81from labml_nn.transformers.utils import subsequent_mask
@@ -116,7 +116,7 @@
-
85class DeepNorm(nn.Module):
+
84class DeepNorm(nn.Module):
@@ -135,9 +135,9 @@
-
92    def __init__(self, alpha: float, normalized_shape: Union[int, List[int], Size], *,
-93                 eps: float = 1e-5,
-94                 elementwise_affine: bool = True):
+
91    def __init__(self, alpha: float, normalized_shape: Union[int, List[int], Size], *,
+92                 eps: float = 1e-5,
+93                 elementwise_affine: bool = True):
@@ -148,9 +148,9 @@
-
101        super().__init__()
-102
-103        self.alpha = alpha
+
100        super().__init__()
+101
+102        self.alpha = alpha
@@ -162,7 +162,7 @@
-
105        self.layer_norm = LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
+
104        self.layer_norm = LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
@@ -177,7 +177,7 @@
-
107    def forward(self, x: torch.Tensor, gx: torch.Tensor):
+
106    def forward(self, x: torch.Tensor, gx: torch.Tensor):
@@ -189,7 +189,7 @@
-
113        return self.layer_norm(x + self.alpha * gx)
+
112        return self.layer_norm(x + self.alpha * gx)
@@ -202,7 +202,7 @@
-
116class DeepNormTransformerLayer(nn.Module):
+
115class DeepNormTransformerLayer(nn.Module):
@@ -223,13 +223,13 @@
-
123    def __init__(self, *,
-124                 d_model: int,
-125                 self_attn: MultiHeadAttention,
-126                 feed_forward: FeedForward,
-127                 deep_norm_alpha: float,
-128                 deep_norm_beta: float,
-129                 ):
+
122    def __init__(self, *,
+123                 d_model: int,
+124                 self_attn: MultiHeadAttention,
+125                 feed_forward: FeedForward,
+126                 deep_norm_alpha: float,
+127                 deep_norm_beta: float,
+128                 ):
@@ -240,10 +240,10 @@
-
137        super().__init__()
-138
-139        self.self_attn = self_attn
-140        self.feed_forward = feed_forward
+
136        super().__init__()
+137
+138        self.self_attn = self_attn
+139        self.feed_forward = feed_forward
@@ -255,8 +255,8 @@
-
142        self.self_attn_norm = DeepNorm(deep_norm_alpha, [d_model])
-143        self.feed_forward_norm = DeepNorm(deep_norm_alpha, [d_model])
+
141        self.self_attn_norm = DeepNorm(deep_norm_alpha, [d_model])
+142        self.feed_forward_norm = DeepNorm(deep_norm_alpha, [d_model])
@@ -268,7 +268,7 @@
-
146        with torch.no_grad():
+
145        with torch.no_grad():
@@ -280,8 +280,8 @@
-
148            feed_forward.layer1.weight *= deep_norm_beta
-149            feed_forward.layer2.weight *= deep_norm_beta
+
147            feed_forward.layer1.weight *= deep_norm_beta
+148            feed_forward.layer2.weight *= deep_norm_beta
@@ -293,7 +293,7 @@
-
152            self_attn.value.linear.weight *= deep_norm_beta
+
151            self_attn.value.linear.weight *= deep_norm_beta
@@ -305,7 +305,7 @@
-
154            self_attn.output.weight *= deep_norm_beta
+
153            self_attn.output.weight *= deep_norm_beta
@@ -317,7 +317,7 @@
-
157        self.mask = None
+
156        self.mask = None
@@ -331,7 +331,7 @@
-
159    def forward(self, x: torch.Tensor):
+
158    def forward(self, x: torch.Tensor):
@@ -343,7 +343,7 @@
-
164        if self.mask is None or self.mask.size(0) != len(x):
+
163        if self.mask is None or self.mask.size(0) != len(x):
@@ -355,7 +355,7 @@
-
166            self.mask = subsequent_mask(len(x)).to(x.device)
+
165            self.mask = subsequent_mask(len(x)).to(x.device)
@@ -367,7 +367,7 @@
-
169        x = self.self_attn_norm(x, self.self_attn(query=x, key=x, value=x, mask=self.mask))
+
168        x = self.self_attn_norm(x, self.self_attn(query=x, key=x, value=x, mask=self.mask))
@@ -379,7 +379,7 @@
-
171        x = self.feed_forward_norm(x, self.feed_forward(x))
+
170        x = self.feed_forward_norm(x, self.feed_forward(x))
@@ -391,7 +391,7 @@
-
174        return x
+
173        return x

Transformer Auto-Regression Experiment

-

Open In Colab Open In Comet

+

Open In Colab

This trains a simple transformer introduced in Attention Is All You Need on an NLP auto-regression task (with Tiny Shakespeare dataset).

-
17import torch
-18from torch import nn
-19
-20from labml import experiment
-21from labml.configs import option
-22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
-23from labml_nn.transformers import TransformerConfigs, Encoder
-24from labml_nn.transformers.utils import subsequent_mask
+
16import torch
+17from torch import nn
+18
+19from labml import experiment
+20from labml.configs import option
+21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+22from labml_nn.transformers import TransformerConfigs, Encoder
+23from labml_nn.transformers.utils import subsequent_mask
@@ -100,7 +100,7 @@
-
27class AutoregressiveTransformer(nn.Module):
+
26class AutoregressiveTransformer(nn.Module):
@@ -117,7 +117,7 @@
-
31    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
+
30    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
@@ -128,10 +128,10 @@
-
38        super().__init__()
-39        self.src_embed = src_embed
-40        self.encoder = encoder
-41        self.generator = generator
+
37        super().__init__()
+38        self.src_embed = src_embed
+39        self.encoder = encoder
+40        self.generator = generator
@@ -143,7 +143,7 @@
-
44        self.mask = None
+
43        self.mask = None
@@ -154,7 +154,7 @@
-
46    def forward(self, x: torch.Tensor):
+
45    def forward(self, x: torch.Tensor):
@@ -166,7 +166,7 @@
-
49        if self.mask is None or self.mask.size(0) != len(x):
+
48        if self.mask is None or self.mask.size(0) != len(x):
@@ -178,7 +178,7 @@
-
51            self.mask = subsequent_mask(len(x)).to(x.device)
+
50            self.mask = subsequent_mask(len(x)).to(x.device)
@@ -190,7 +190,7 @@
-
53        x = self.src_embed(x)
+
52        x = self.src_embed(x)
@@ -202,7 +202,7 @@
-
55        x = self.encoder(x, self.mask)
+
54        x = self.encoder(x, self.mask)
@@ -214,7 +214,7 @@
-
57        x = self.generator(x)
+
56        x = self.generator(x)
@@ -226,7 +226,7 @@
-
61        return x, None
+
60        return x, None
@@ -240,7 +240,7 @@
-
64class Configs(NLPAutoRegressionConfigs):
+
63class Configs(NLPAutoRegressionConfigs):
@@ -252,7 +252,7 @@
-
73    model: AutoregressiveTransformer
+
72    model: AutoregressiveTransformer
@@ -264,7 +264,7 @@
-
75    transformer: TransformerConfigs
+
74    transformer: TransformerConfigs
@@ -276,8 +276,8 @@
-
78@option(Configs.transformer, 'Transformer')
-79def _transformer_configs(c: Configs):
+
77@option(Configs.transformer, 'Transformer')
+78def _transformer_configs(c: Configs):
@@ -289,7 +289,7 @@
-
86    conf = TransformerConfigs()
+
85    conf = TransformerConfigs()
@@ -301,8 +301,8 @@
-
88    conf.n_src_vocab = c.n_tokens
-89    conf.n_tgt_vocab = c.n_tokens
+
87    conf.n_src_vocab = c.n_tokens
+88    conf.n_tgt_vocab = c.n_tokens
@@ -314,7 +314,7 @@
-
91    conf.d_model = c.d_model
+
90    conf.d_model = c.d_model
@@ -326,7 +326,7 @@
-
94    return conf
+
93    return conf
@@ -338,8 +338,8 @@
-
97@option(Configs.model)
-98def _model(c: Configs):
+
96@option(Configs.model)
+97def _model(c: Configs):
@@ -350,11 +350,11 @@
-
102    m = AutoregressiveTransformer(c.transformer.encoder,
-103                                  c.transformer.src_embed,
-104                                  c.transformer.generator).to(c.device)
-105
-106    return m
+
101    m = AutoregressiveTransformer(c.transformer.encoder,
+102                                  c.transformer.src_embed,
+103                                  c.transformer.generator).to(c.device)
+104
+105    return m
@@ -365,7 +365,7 @@
-
109def main():
+
108def main():
@@ -377,7 +377,7 @@
-
111    experiment.create(name="transformer")
+
110    experiment.create(name="transformer")
@@ -389,7 +389,7 @@
-
113    conf = Configs()
+
112    conf = Configs()
@@ -401,7 +401,7 @@
-
115    experiment.configs(conf, {
+
114    experiment.configs(conf, {
@@ -413,7 +413,7 @@
-
117        'tokenizer': 'character',
+
116        'tokenizer': 'character',
@@ -425,7 +425,7 @@
-
119        'prompt_separator': '',
+
118        'prompt_separator': '',
@@ -437,7 +437,7 @@
-
121        'prompt': 'It is ',
+
120        'prompt': 'It is ',
@@ -449,7 +449,7 @@
-
123        'text': 'tiny_shakespeare',
+
122        'text': 'tiny_shakespeare',
@@ -461,7 +461,7 @@
-
126        'seq_len': 512,
+
125        'seq_len': 512,
@@ -473,7 +473,7 @@
-
128        'epochs': 32,
+
127        'epochs': 32,
@@ -485,7 +485,7 @@
-
130        'batch_size': 16,
+
129        'batch_size': 16,
@@ -497,7 +497,7 @@
-
133        'inner_iterations': 10,
+
132        'inner_iterations': 10,
@@ -509,9 +509,9 @@
-
136        'd_model': 256,
-137        'transformer.n_heads': 16,
-138        'transformer.ffn.d_ff': 1024,
+
135        'd_model': 256,
+136        'transformer.n_heads': 16,
+137        'transformer.ffn.d_ff': 1024,
@@ -523,9 +523,9 @@
-
141        'optimizer.optimizer': 'Noam',
-142        'optimizer.learning_rate': 1.,
-143    })
+
140        'optimizer.optimizer': 'Noam',
+141        'optimizer.learning_rate': 1.,
+142    })
@@ -537,7 +537,7 @@
-
146    experiment.add_pytorch_models({'model': conf.model})
+
145    experiment.add_pytorch_models({'model': conf.model})
@@ -549,7 +549,7 @@
-
149    with experiment.start():
+
148    with experiment.start():
@@ -561,7 +561,7 @@
-
151        conf.run()
+
150        conf.run()
@@ -573,8 +573,8 @@
-
155if __name__ == '__main__':
-156    main()
+
154if __name__ == '__main__':
+155    main()

Multi-Headed Attention (MHA)

-

Open In Colab Open In Comet

+

Open In Colab

This is a tutorial/implementation of multi-headed attention from paper Attention Is All You Need in PyTorch. The implementation is inspired from Annotated Transformer.

Here is the training code that uses a basic transformer with MHA for NLP auto-regression.

Here is an experiment implementation that trains a simple transformer.

-
25import math
-26from typing import Optional, List
-27
-28import torch
-29from torch import nn
-30
-31from labml import tracker
+
24import math
+25from typing import Optional, List
+26
+27import torch
+28from torch import nn
+29
+30from labml import tracker
@@ -102,7 +102,7 @@
-
34class PrepareForMultiHeadAttention(nn.Module):
+
33class PrepareForMultiHeadAttention(nn.Module):
@@ -113,8 +113,8 @@
-
45    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
-46        super().__init__()
+
44    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
+45        super().__init__()
@@ -126,7 +126,7 @@
-
48        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
+
47        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
@@ -138,7 +138,7 @@
-
50        self.heads = heads
+
49        self.heads = heads
@@ -150,7 +150,7 @@
-
52        self.d_k = d_k
+
51        self.d_k = d_k
@@ -161,7 +161,7 @@
-
54    def forward(self, x: torch.Tensor):
+
53    def forward(self, x: torch.Tensor):
@@ -175,7 +175,7 @@
-
58        head_shape = x.shape[:-1]
+
57        head_shape = x.shape[:-1]
@@ -187,7 +187,7 @@
-
61        x = self.linear(x)
+
60        x = self.linear(x)
@@ -199,7 +199,7 @@
-
64        x = x.view(*head_shape, self.heads, self.d_k)
+
63        x = x.view(*head_shape, self.heads, self.d_k)
@@ -213,7 +213,7 @@
-
67        return x
+
66        return x
@@ -256,7 +256,7 @@ M834 80h400000v40h-400000z">
70class MultiHeadAttention(nn.Module):
+
69class MultiHeadAttention(nn.Module):
@@ -274,7 +274,7 @@ M834 80h400000v40h-400000z">
91    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
+
90    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
@@ -285,7 +285,7 @@ M834 80h400000v40h-400000z">
97        super().__init__()
+
96        super().__init__()
@@ -297,7 +297,7 @@ M834 80h400000v40h-400000z">
100        self.d_k = d_model // heads
+
99        self.d_k = d_model // heads
@@ -309,7 +309,7 @@ M834 80h400000v40h-400000z">
102        self.heads = heads
+
101        self.heads = heads
@@ -324,9 +324,9 @@ M834 80h400000v40h-400000z">
105        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
-106        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
-107        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
+
104        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+105        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+106        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
@@ -339,7 +339,7 @@ M834 80h400000v40h-400000z">
110        self.softmax = nn.Softmax(dim=1)
+
109        self.softmax = nn.Softmax(dim=1)
@@ -351,7 +351,7 @@ M834 80h400000v40h-400000z">
113        self.output = nn.Linear(d_model, d_model)
+
112        self.output = nn.Linear(d_model, d_model)
@@ -363,7 +363,7 @@ M834 80h400000v40h-400000z">
115        self.dropout = nn.Dropout(dropout_prob)
+
114        self.dropout = nn.Dropout(dropout_prob)
@@ -375,7 +375,7 @@ M834 80h400000v40h-400000z">
117        self.scale = 1 / math.sqrt(self.d_k)
+
116        self.scale = 1 / math.sqrt(self.d_k)
@@ -387,7 +387,7 @@ M834 80h400000v40h-400000z">
120        self.attn = None
+
119        self.attn = None
@@ -400,7 +400,7 @@ M834 80h400000v40h-400000z">
122    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+
121    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
@@ -412,7 +412,7 @@ M834 80h400000v40h-400000z">
130        return torch.einsum('ibhd,jbhd->ijbh', query, key)
+
129        return torch.einsum('ibhd,jbhd->ijbh', query, key)
@@ -426,7 +426,7 @@ M834 80h400000v40h-400000z">
132    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
+
131    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
@@ -437,9 +437,9 @@ M834 80h400000v40h-400000z">
138        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
-139        assert mask.shape[1] == key_shape[0]
-140        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
+
137        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
+138        assert mask.shape[1] == key_shape[0]
+139        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
@@ -451,7 +451,7 @@ M834 80h400000v40h-400000z">
143        mask = mask.unsqueeze(-1)
+
142        mask = mask.unsqueeze(-1)
@@ -464,7 +464,7 @@ M834 80h400000v40h-400000z">
146        return mask
+
145        return mask
@@ -487,11 +487,11 @@ M834 80h400000v40h-400000z">
148    def forward(self, *,
-149                query: torch.Tensor,
-150                key: torch.Tensor,
-151                value: torch.Tensor,
-152                mask: Optional[torch.Tensor] = None):
+
147    def forward(self, *,
+148                query: torch.Tensor,
+149                key: torch.Tensor,
+150                value: torch.Tensor,
+151                mask: Optional[torch.Tensor] = None):
@@ -507,10 +507,10 @@ M834 80h400000v40h-400000z">
164        seq_len, batch_size, _ = query.shape
-165
-166        if mask is not None:
-167            mask = self.prepare_mask(mask, query.shape, key.shape)
+
163        seq_len, batch_size, _ = query.shape
+164
+165        if mask is not None:
+166            mask = self.prepare_mask(mask, query.shape, key.shape)
@@ -526,9 +526,9 @@ M834 80h400000v40h-400000z">
171        query = self.query(query)
-172        key = self.key(key)
-173        value = self.value(value)
+
170        query = self.query(query)
+171        key = self.key(key)
+172        value = self.value(value)
@@ -541,7 +541,7 @@ M834 80h400000v40h-400000z">
177        scores = self.get_scores(query, key)
+
176        scores = self.get_scores(query, key)
@@ -564,7 +564,7 @@ M834 80h400000v40h-400000z">
180        scores *= self.scale
+
179        scores *= self.scale
@@ -576,8 +576,8 @@ M834 80h400000v40h-400000z">
183        if mask is not None:
-184            scores = scores.masked_fill(mask == 0, float('-inf'))
+
182        if mask is not None:
+183            scores = scores.masked_fill(mask == 0, float('-inf'))
@@ -600,7 +600,7 @@ M834 80h400000v40h-400000z">
188        attn = self.softmax(scores)
+
187        attn = self.softmax(scores)
@@ -612,7 +612,7 @@ M834 80h400000v40h-400000z">
191        tracker.debug('attn', attn)
+
190        tracker.debug('attn', attn)
@@ -624,7 +624,7 @@ M834 80h400000v40h-400000z">
194        attn = self.dropout(attn)
+
193        attn = self.dropout(attn)
@@ -647,7 +647,7 @@ M834 80h400000v40h-400000z">
198        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
+
197        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
@@ -659,7 +659,7 @@ M834 80h400000v40h-400000z">
201        self.attn = attn.detach()
+
200        self.attn = attn.detach()
@@ -671,7 +671,7 @@ M834 80h400000v40h-400000z">
204        x = x.reshape(seq_len, batch_size, -1)
+
203        x = x.reshape(seq_len, batch_size, -1)
@@ -683,7 +683,7 @@ M834 80h400000v40h-400000z">
207        return self.output(x)
+
206        return self.output(x)

Transformer Encoder and Decoder Models

-

Open In Colab Open In Comet

+

Open In Colab

-
14import math
-15
-16import torch
-17import torch.nn as nn
-18
-19from labml_nn.utils import clone_module_list
-20from .feed_forward import FeedForward
-21from .mha import MultiHeadAttention
-22from .positional_encoding import get_positional_encoding
+
13import math
+14
+15import torch
+16import torch.nn as nn
+17
+18from labml_nn.utils import clone_module_list
+19from .feed_forward import FeedForward
+20from .mha import MultiHeadAttention
+21from .positional_encoding import get_positional_encoding
@@ -100,7 +100,7 @@
-
25class EmbeddingsWithPositionalEncoding(nn.Module):
+
24class EmbeddingsWithPositionalEncoding(nn.Module):
@@ -111,11 +111,11 @@
-
32    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
-33        super().__init__()
-34        self.linear = nn.Embedding(n_vocab, d_model)
-35        self.d_model = d_model
-36        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
+
31    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
+32        super().__init__()
+33        self.linear = nn.Embedding(n_vocab, d_model)
+34        self.d_model = d_model
+35        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
@@ -126,9 +126,9 @@
-
38    def forward(self, x: torch.Tensor):
-39        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
-40        return self.linear(x) * math.sqrt(self.d_model) + pe
+
37    def forward(self, x: torch.Tensor):
+38        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
+39        return self.linear(x) * math.sqrt(self.d_model) + pe
@@ -141,7 +141,7 @@
-
43class EmbeddingsWithLearnedPositionalEncoding(nn.Module):
+
42class EmbeddingsWithLearnedPositionalEncoding(nn.Module):
@@ -152,11 +152,11 @@
-
50    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
-51        super().__init__()
-52        self.linear = nn.Embedding(n_vocab, d_model)
-53        self.d_model = d_model
-54        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
+
49    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
+50        super().__init__()
+51        self.linear = nn.Embedding(n_vocab, d_model)
+52        self.d_model = d_model
+53        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
@@ -167,9 +167,9 @@
-
56    def forward(self, x: torch.Tensor):
-57        pe = self.positional_encodings[:x.shape[0]]
-58        return self.linear(x) * math.sqrt(self.d_model) + pe
+
55    def forward(self, x: torch.Tensor):
+56        pe = self.positional_encodings[:x.shape[0]]
+57        return self.linear(x) * math.sqrt(self.d_model) + pe
@@ -184,7 +184,7 @@
-
61class TransformerLayer(nn.Module):
+
60class TransformerLayer(nn.Module):
@@ -205,12 +205,12 @@
-
79    def __init__(self, *,
-80                 d_model: int,
-81                 self_attn: MultiHeadAttention,
-82                 src_attn: MultiHeadAttention = None,
-83                 feed_forward: FeedForward,
-84                 dropout_prob: float):
+
78    def __init__(self, *,
+79                 d_model: int,
+80                 self_attn: MultiHeadAttention,
+81                 src_attn: MultiHeadAttention = None,
+82                 feed_forward: FeedForward,
+83                 dropout_prob: float):
@@ -221,16 +221,16 @@
-
92        super().__init__()
-93        self.size = d_model
-94        self.self_attn = self_attn
-95        self.src_attn = src_attn
-96        self.feed_forward = feed_forward
-97        self.dropout = nn.Dropout(dropout_prob)
-98        self.norm_self_attn = nn.LayerNorm([d_model])
-99        if self.src_attn is not None:
-100            self.norm_src_attn = nn.LayerNorm([d_model])
-101        self.norm_ff = nn.LayerNorm([d_model])
+
91        super().__init__()
+92        self.size = d_model
+93        self.self_attn = self_attn
+94        self.src_attn = src_attn
+95        self.feed_forward = feed_forward
+96        self.dropout = nn.Dropout(dropout_prob)
+97        self.norm_self_attn = nn.LayerNorm([d_model])
+98        if self.src_attn is not None:
+99            self.norm_src_attn = nn.LayerNorm([d_model])
+100        self.norm_ff = nn.LayerNorm([d_model])
@@ -242,7 +242,7 @@
-
103        self.is_save_ff_input = False
+
102        self.is_save_ff_input = False
@@ -253,11 +253,11 @@
-
105    def forward(self, *,
-106                x: torch.Tensor,
-107                mask: torch.Tensor,
-108                src: torch.Tensor = None,
-109                src_mask: torch.Tensor = None):
+
104    def forward(self, *,
+105                x: torch.Tensor,
+106                mask: torch.Tensor,
+107                src: torch.Tensor = None,
+108                src_mask: torch.Tensor = None):
@@ -269,7 +269,7 @@
-
111        z = self.norm_self_attn(x)
+
110        z = self.norm_self_attn(x)
@@ -281,7 +281,7 @@
-
113        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
+
112        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
@@ -293,7 +293,7 @@
-
115        x = x + self.dropout(self_attn)
+
114        x = x + self.dropout(self_attn)
@@ -305,7 +305,7 @@
-
120        if src is not None:
+
119        if src is not None:
@@ -317,7 +317,7 @@
-
122            z = self.norm_src_attn(x)
+
121            z = self.norm_src_attn(x)
@@ -329,7 +329,7 @@
-
124            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
+
123            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
@@ -341,7 +341,7 @@
-
126            x = x + self.dropout(attn_src)
+
125            x = x + self.dropout(attn_src)
@@ -353,7 +353,7 @@
-
129        z = self.norm_ff(x)
+
128        z = self.norm_ff(x)
@@ -365,8 +365,8 @@
-
131        if self.is_save_ff_input:
-132            self.ff_input = z.clone()
+
130        if self.is_save_ff_input:
+131            self.ff_input = z.clone()
@@ -378,7 +378,7 @@
-
134        ff = self.feed_forward(z)
+
133        ff = self.feed_forward(z)
@@ -390,9 +390,9 @@
-
136        x = x + self.dropout(ff)
-137
-138        return x
+
135        x = x + self.dropout(ff)
+136
+137        return x
@@ -405,7 +405,7 @@
-
141class Encoder(nn.Module):
+
140class Encoder(nn.Module):
@@ -416,8 +416,8 @@
-
148    def __init__(self, layer: TransformerLayer, n_layers: int):
-149        super().__init__()
+
147    def __init__(self, layer: TransformerLayer, n_layers: int):
+148        super().__init__()
@@ -429,7 +429,7 @@
-
151        self.layers = clone_module_list(layer, n_layers)
+
150        self.layers = clone_module_list(layer, n_layers)
@@ -441,7 +441,7 @@
-
153        self.norm = nn.LayerNorm([layer.size])
+
152        self.norm = nn.LayerNorm([layer.size])
@@ -452,7 +452,7 @@
-
155    def forward(self, x: torch.Tensor, mask: torch.Tensor):
+
154    def forward(self, x: torch.Tensor, mask: torch.Tensor):
@@ -464,8 +464,8 @@
-
157        for layer in self.layers:
-158            x = layer(x=x, mask=mask)
+
156        for layer in self.layers:
+157            x = layer(x=x, mask=mask)
@@ -477,7 +477,7 @@
-
160        return self.norm(x)
+
159        return self.norm(x)
@@ -490,7 +490,7 @@
-
163class Decoder(nn.Module):
+
162class Decoder(nn.Module):
@@ -501,8 +501,8 @@
-
170    def __init__(self, layer: TransformerLayer, n_layers: int):
-171        super().__init__()
+
169    def __init__(self, layer: TransformerLayer, n_layers: int):
+170        super().__init__()
@@ -514,7 +514,7 @@
-
173        self.layers = clone_module_list(layer, n_layers)
+
172        self.layers = clone_module_list(layer, n_layers)
@@ -526,7 +526,7 @@
-
175        self.norm = nn.LayerNorm([layer.size])
+
174        self.norm = nn.LayerNorm([layer.size])
@@ -537,7 +537,7 @@
-
177    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+
176    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
@@ -549,8 +549,8 @@
-
179        for layer in self.layers:
-180            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
+
178        for layer in self.layers:
+179            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
@@ -562,7 +562,7 @@
-
182        return self.norm(x)
+
181        return self.norm(x)
@@ -577,7 +577,7 @@
-
185class Generator(nn.Module):
+
184class Generator(nn.Module):
@@ -588,9 +588,9 @@
-
195    def __init__(self, n_vocab: int, d_model: int):
-196        super().__init__()
-197        self.projection = nn.Linear(d_model, n_vocab)
+
194    def __init__(self, n_vocab: int, d_model: int):
+195        super().__init__()
+196        self.projection = nn.Linear(d_model, n_vocab)
@@ -601,8 +601,8 @@
-
199    def forward(self, x):
-200        return self.projection(x)
+
198    def forward(self, x):
+199        return self.projection(x)
@@ -615,7 +615,7 @@
-
203class EncoderDecoder(nn.Module):
+
202class EncoderDecoder(nn.Module):
@@ -626,13 +626,13 @@
-
210    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
-211        super().__init__()
-212        self.encoder = encoder
-213        self.decoder = decoder
-214        self.src_embed = src_embed
-215        self.tgt_embed = tgt_embed
-216        self.generator = generator
+
209    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
+210        super().__init__()
+211        self.encoder = encoder
+212        self.decoder = decoder
+213        self.src_embed = src_embed
+214        self.tgt_embed = tgt_embed
+215        self.generator = generator
@@ -644,9 +644,9 @@
-
220        for p in self.parameters():
-221            if p.dim() > 1:
-222                nn.init.xavier_uniform_(p)
+
219        for p in self.parameters():
+220            if p.dim() > 1:
+221                nn.init.xavier_uniform_(p)
@@ -657,7 +657,7 @@
-
224    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+
223    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
@@ -669,7 +669,7 @@
-
226        enc = self.encode(src, src_mask)
+
225        enc = self.encode(src, src_mask)
@@ -681,7 +681,7 @@
-
228        return self.decode(enc, src_mask, tgt, tgt_mask)
+
227        return self.decode(enc, src_mask, tgt, tgt_mask)
@@ -692,8 +692,8 @@
-
230    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
-231        return self.encoder(self.src_embed(src), src_mask)
+
229    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
+230        return self.encoder(self.src_embed(src), src_mask)
@@ -704,8 +704,8 @@
-
233    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
-234        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
+
232    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
+233        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
-
45    experiment.create(name="roper_addition", comment="rotary value 7", writers={'screen', 'labml', 'comet'})
+
45    experiment.create(name="roper_addition", comment="rotary value 7", writers={'screen', 'labml'})
diff --git a/docs/transformers/rope/value_pe/experiment.html b/docs/transformers/rope/value_pe/experiment.html index a9fcde20..012d96c2 100644 --- a/docs/transformers/rope/value_pe/experiment.html +++ b/docs/transformers/rope/value_pe/experiment.html @@ -1,5 +1,5 @@ - + diff --git a/docs/transformers/rope/value_pe/index.html b/docs/transformers/rope/value_pe/index.html index 894aa595..c9cb99d6 100644 --- a/docs/transformers/rope/value_pe/index.html +++ b/docs/transformers/rope/value_pe/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/transformers/switch/experiment.html b/docs/transformers/switch/experiment.html index da78d737..4b9e4069 100644 --- a/docs/transformers/switch/experiment.html +++ b/docs/transformers/switch/experiment.html @@ -1,5 +1,5 @@ - + diff --git a/docs/transformers/switch/index.html b/docs/transformers/switch/index.html index 9562fc66..bb3f6e75 100644 --- a/docs/transformers/switch/index.html +++ b/docs/transformers/switch/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/transformers/switch/readme.html b/docs/transformers/switch/readme.html index 31dedc63..94b76e85 100644 --- a/docs/transformers/switch/readme.html +++ b/docs/transformers/switch/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/transformers/utils.html b/docs/transformers/utils.html index f03a78f4..5e2332ab 100644 --- a/docs/transformers/utils.html +++ b/docs/transformers/utils.html @@ -1,5 +1,5 @@ - + diff --git a/docs/transformers/vit/experiment.html b/docs/transformers/vit/experiment.html index 8e0ab9ed..c8128d6d 100644 --- a/docs/transformers/vit/experiment.html +++ b/docs/transformers/vit/experiment.html @@ -1,5 +1,5 @@ - + diff --git a/docs/transformers/vit/index.html b/docs/transformers/vit/index.html index 6696998f..0ebd074e 100644 --- a/docs/transformers/vit/index.html +++ b/docs/transformers/vit/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/transformers/vit/readme.html b/docs/transformers/vit/readme.html index a18dfa04..d5d52600 100644 --- a/docs/transformers/vit/readme.html +++ b/docs/transformers/vit/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/transformers/xl/experiment.html b/docs/transformers/xl/experiment.html index c34bcb20..08ff84dd 100644 --- a/docs/transformers/xl/experiment.html +++ b/docs/transformers/xl/experiment.html @@ -1,5 +1,5 @@ - + diff --git a/docs/transformers/xl/index.html b/docs/transformers/xl/index.html index 5d012e86..db13a79e 100644 --- a/docs/transformers/xl/index.html +++ b/docs/transformers/xl/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/transformers/xl/readme.html b/docs/transformers/xl/readme.html index 5174f2b9..402f14b7 100644 --- a/docs/transformers/xl/readme.html +++ b/docs/transformers/xl/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/transformers/xl/relative_mha.html b/docs/transformers/xl/relative_mha.html index 52dcd665..f617d7da 100644 --- a/docs/transformers/xl/relative_mha.html +++ b/docs/transformers/xl/relative_mha.html @@ -1,5 +1,5 @@ - + diff --git a/docs/uncertainty/evidence/experiment.html b/docs/uncertainty/evidence/experiment.html index 3a759daa..f52e0cba 100644 --- a/docs/uncertainty/evidence/experiment.html +++ b/docs/uncertainty/evidence/experiment.html @@ -1,5 +1,5 @@ - + diff --git a/docs/uncertainty/evidence/index.html b/docs/uncertainty/evidence/index.html index 2e340588..74fa8a00 100644 --- a/docs/uncertainty/evidence/index.html +++ b/docs/uncertainty/evidence/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/uncertainty/evidence/readme.html b/docs/uncertainty/evidence/readme.html index a4209cc2..96913440 100644 --- a/docs/uncertainty/evidence/readme.html +++ b/docs/uncertainty/evidence/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/uncertainty/index.html b/docs/uncertainty/index.html index 94af4c96..c885d6c2 100644 --- a/docs/uncertainty/index.html +++ b/docs/uncertainty/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/uncertainty/readme.html b/docs/uncertainty/readme.html index 4cb20dd5..cba06c9c 100644 --- a/docs/uncertainty/readme.html +++ b/docs/uncertainty/readme.html @@ -1,5 +1,5 @@ - + diff --git a/docs/unet/carvana.html b/docs/unet/carvana.html index b860d11c..83cff98e 100644 --- a/docs/unet/carvana.html +++ b/docs/unet/carvana.html @@ -1,5 +1,5 @@ - + diff --git a/docs/unet/experiment.html b/docs/unet/experiment.html index d65e2aba..26cfabe3 100644 --- a/docs/unet/experiment.html +++ b/docs/unet/experiment.html @@ -1,5 +1,5 @@ - + diff --git a/docs/unet/index.html b/docs/unet/index.html index 628b9b67..e3476026 100644 --- a/docs/unet/index.html +++ b/docs/unet/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/utils/index.html b/docs/utils/index.html index 5a2b4386..3f12a5a2 100644 --- a/docs/utils/index.html +++ b/docs/utils/index.html @@ -1,5 +1,5 @@ - + diff --git a/docs/utils/tokenizer.html b/docs/utils/tokenizer.html index a473f647..286084d7 100644 --- a/docs/utils/tokenizer.html +++ b/docs/utils/tokenizer.html @@ -1,5 +1,5 @@ - + diff --git a/labml_nn/diffusion/ddpm/__init__.py b/labml_nn/diffusion/ddpm/__init__.py index b678fb37..d78f6ad4 100644 --- a/labml_nn/diffusion/ddpm/__init__.py +++ b/labml_nn/diffusion/ddpm/__init__.py @@ -144,7 +144,7 @@ That is, we are training to predict the noise. ### Simplified loss -$$L_simple(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert +$$L_{\text{simple}}(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert \epsilon - \textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t) \bigg\Vert^2 \Bigg]$$ @@ -265,7 +265,7 @@ class DenoiseDiffusion: """ #### Simplified Loss - $$L_simple(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert + $$L_{\text{simple}}(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert \epsilon - \textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t) \bigg\Vert^2 \Bigg]$$ """