diff --git a/docs/activations/fta/experiment.html b/docs/activations/fta/experiment.html
index 3d0d2ce3..8f2d2e91 100644
--- a/docs/activations/fta/experiment.html
+++ b/docs/activations/fta/experiment.html
@@ -70,9 +70,9 @@ #

Fuzzy Tiling Activation Experiment

+

Open In Colab Open In Comet

Here we train a transformer that uses Fuzzy Tiling Activation in the Feed-Forward Network. We use it as a language model and train it on the Tiny Shakespeare dataset for demonstration.

However, this is probably not the ideal task for FTA, and we believe FTA is more suitable for modeling data with continuous variables.
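As a rough sketch of what using FTA in the feed-forward network could look like: the usual hidden activation is replaced by FTA, and since FTA expands each scalar into a vector of tile activations, the second linear layer has to consume `d_ff * expansion_factor` features. This is only an illustrative sketch (the class name `FeedForwardFTA` and its signature are hypothetical, not the experiment's actual code):

```python
import torch
import torch.nn as nn


class FeedForwardFTA(nn.Module):
    """Position-wise feed-forward block with an FTA hidden activation (sketch)."""

    def __init__(self, d_model: int, d_ff: int, fta: nn.Module, expansion_factor: int):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)
        # FTA maps each of the d_ff scalars to `expansion_factor` tile activations
        self.fta = fta
        self.layer2 = nn.Linear(d_ff * expansion_factor, d_model)

    def forward(self, x: torch.Tensor):
        return self.layer2(self.fta(self.layer1(x)))
```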

-

Open In Colab Open In Comet

diff --git a/docs/activations/fta/index.html b/docs/activations/fta/index.html
index 32bee1ea..2c38a4ab 100644
--- a/docs/activations/fta/index.html
+++ b/docs/activations/fta/index.html
@@ -70,6 +70,7 @@ #

Fuzzy Tiling Activations (FTA)

+

Open In Colab Open In Comet

This is a PyTorch implementation/tutorial of Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online.

Fuzzy tiling activations are a form of sparse activations based on binning.

Binning is the classification of a scalar value into a bin based on intervals. One problem with binning is that it gives zero gradients for most values (except at bin boundaries). The other is that binning loses precision if the bin intervals are large.

@@ -89,7 +90,6 @@

FTA uses this to create soft boundaries between bins.
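To make the soft binning concrete, here is a minimal sketch of the FTA function following the paper's definition. The actual implementation in `labml_nn.activations.fta` may differ in details; treat the parameter handling here as illustrative:

```python
import torch
import torch.nn as nn


class FTA(nn.Module):
    """Fuzzy Tiling Activation (sketch of the paper's definition)."""

    def __init__(self, lower_limit: float, upper_limit: float, delta: float, eta: float):
        super().__init__()
        # Tiling vector: the lower boundaries of the bins
        self.register_buffer('c', torch.arange(lower_limit, upper_limit, delta))
        # Each input scalar expands into this many tile activations
        self.expansion_factor = len(self.c)
        self.delta = delta
        self.eta = eta

    def fuzzy_i_plus(self, x: torch.Tensor):
        # Fuzzy indicator: a linear ramp up to eta, and 1 beyond it
        return (x <= self.eta) * x + (x > self.eta)

    def forward(self, z: torch.Tensor):
        # Distance of z from the bin [c, c + delta]; zero when z falls inside the bin
        z = z.unsqueeze(-1)
        d = torch.clip(self.c - z, min=0.) + torch.clip(z - self.delta - self.c, min=0.)
        # Activation is 1 inside a bin and fades linearly to 0 within eta of its boundary
        out = 1. - self.fuzzy_i_plus(d)
        # Concatenate the tile activations into the feature dimension
        return out.flatten(-2)
```

With eta equal to zero this reduces to hard binning (one-hot tiles with zero gradients almost everywhere); a positive eta gives non-zero gradients near the bin boundaries.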

Here's a simple experiment that uses FTA in a transformer.

-

Open In Colab Open In Comet

diff --git a/docs/capsule_networks/index.html b/docs/capsule_networks/index.html
index 6f7bfa47..4cd28402 100644
--- a/docs/capsule_networks/index.html
+++ b/docs/capsule_networks/index.html
@@ -75,15 +75,15 @@

This file holds the implementations of the core modules of Capsule Networks.

I used jindongwang/Pytorch-CapsuleNet to clarify some of the confusion I had with the paper.

Here's a notebook for training a Capsule Network on the MNIST dataset.
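Before the code, a small usage sketch of how the modules below fit together; the shapes mirror the MNIST model further down, but the snippet itself is only illustrative:

```python
import torch

from labml_nn.capsule_networks import Squash, Router

# A batch of 16 images, each giving 32 * 6 * 6 primary capsules of 8 dimensions
u = torch.randn(16, 32 * 6 * 6, 8)

# Squash keeps capsule directions but shrinks their lengths into [0, 1)
v = Squash()(u)
assert v.shape == u.shape

# Route the primary capsules to 10 digit capsules of 16 dimensions, with 3 iterations
router = Router(in_caps=32 * 6 * 6, out_caps=10, in_d=8, out_d=16, iterations=3)
digit_caps = router(u)
assert digit_caps.shape == (16, 10, 16)

# The length of each digit capsule acts as the probability that the digit is present
lengths = (digit_caps ** 2).sum(dim=-1).sqrt()
```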

-

Open In Colab View Run Open In Comet

+

Open In Colab View Run

-
34import torch.nn as nn
-35import torch.nn.functional as F
-36import torch.utils.data
-37
-38from labml_helpers.module import Module
+
33import torch.nn as nn
+34import torch.nn.functional as F
+35import torch.utils.data
+36
+37from labml_helpers.module import Module
@@ -98,7 +98,7 @@
-
41class Squash(Module):
+
40class Squash(Module):
@@ -109,9 +109,9 @@
-
56    def __init__(self, epsilon=1e-8):
-57        super().__init__()
-58        self.epsilon = epsilon
+
55    def __init__(self, epsilon=1e-8):
+56        super().__init__()
+57        self.epsilon = epsilon
@@ -125,7 +125,7 @@
-
60    def forward(self, s: torch.Tensor):
+
59    def forward(self, s: torch.Tensor):
@@ -137,7 +137,7 @@
-
66        s2 = (s ** 2).sum(dim=-1, keepdims=True)
+
65        s2 = (s ** 2).sum(dim=-1, keepdims=True)
@@ -159,7 +159,7 @@
72        return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))
+
71        return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))
@@ -173,7 +173,7 @@
75class Router(Module):
+
74class Router(Module):
@@ -191,7 +191,7 @@
86    def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
+
85    def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
@@ -202,12 +202,12 @@
93        super().__init__()
-94        self.in_caps = in_caps
-95        self.out_caps = out_caps
-96        self.iterations = iterations
-97        self.softmax = nn.Softmax(dim=1)
-98        self.squash = Squash()
+
92        super().__init__()
+93        self.in_caps = in_caps
+94        self.out_caps = out_caps
+95        self.iterations = iterations
+96        self.softmax = nn.Softmax(dim=1)
+97        self.squash = Squash()
@@ -219,7 +219,7 @@
102        self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)
+
101        self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)
@@ -233,7 +233,7 @@
104    def forward(self, u: torch.Tensor):
+
103    def forward(self, u: torch.Tensor):
@@ -245,7 +245,7 @@
113        u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)
+
112        u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)
@@ -257,9 +257,9 @@
118        b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps)
-119
-120        v = None
+
117        b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps)
+118
+119        v = None
@@ -271,7 +271,7 @@
123        for i in range(self.iterations):
+
122        for i in range(self.iterations):
@@ -283,7 +283,7 @@
125            c = self.softmax(b)
+
124            c = self.softmax(b)
@@ -295,7 +295,7 @@
127            s = torch.einsum('bij,bijm->bjm', c, u_hat)
+
126            s = torch.einsum('bij,bijm->bjm', c, u_hat)
@@ -307,7 +307,7 @@
129            v = self.squash(s)
+
128            v = self.squash(s)
@@ -319,7 +319,7 @@
131            a = torch.einsum('bjm,bijm->bij', v, u_hat)
+
130            a = torch.einsum('bjm,bijm->bij', v, u_hat)
@@ -331,9 +331,9 @@
133            b = b + a
-134
-135        return v
+
132            b = b + a
+133
+134        return v
@@ -349,7 +349,7 @@
138class MarginLoss(Module):
+
137class MarginLoss(Module):
@@ -360,13 +360,13 @@
158    def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
-159        super().__init__()
-160
-161        self.m_negative = m_negative
-162        self.m_positive = m_positive
-163        self.lambda_ = lambda_
-164        self.n_labels = n_labels
+
157    def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
+158        super().__init__()
+159
+160        self.m_negative = m_negative
+161        self.m_positive = m_positive
+162        self.lambda_ = lambda_
+163        self.n_labels = n_labels
@@ -383,7 +383,7 @@
166    def forward(self, v: torch.Tensor, labels: torch.Tensor):
+
165    def forward(self, v: torch.Tensor, labels: torch.Tensor):
@@ -395,7 +395,7 @@
174        v_norm = torch.sqrt((v ** 2).sum(dim=-1))
+
173        v_norm = torch.sqrt((v ** 2).sum(dim=-1))
@@ -409,7 +409,7 @@
178        labels = torch.eye(self.n_labels, device=labels.device)[labels]
+
177        labels = torch.eye(self.n_labels, device=labels.device)[labels]
@@ -423,8 +423,8 @@
184        loss = labels * F.relu(self.m_positive - v_norm) + \
-185               self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative)
+
183        loss = labels * F.relu(self.m_positive - v_norm) + \
+184               self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative)
@@ -436,7 +436,7 @@
188        return loss.sum(dim=-1).mean()
+
187        return loss.sum(dim=-1).mean()
-
16from typing import Any
-17
-18import torch.nn as nn
-19import torch.nn.functional as F
-20import torch.utils.data
-21
-22from labml import experiment, tracker
-23from labml.configs import option
-24from labml_helpers.datasets.mnist import MNISTConfigs
-25from labml_helpers.metrics.accuracy import AccuracyDirect
-26from labml_helpers.module import Module
-27from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
-28from labml_nn.capsule_networks import Squash, Router, MarginLoss
+
14from typing import Any
+15
+16import torch.nn as nn
+17import torch.nn.functional as F
+18import torch.utils.data
+19
+20from labml import experiment, tracker
+21from labml.configs import option
+22from labml_helpers.datasets.mnist import MNISTConfigs
+23from labml_helpers.metrics.accuracy import AccuracyDirect
+24from labml_helpers.module import Module
+25from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
+26from labml_nn.capsule_networks import Squash, Router, MarginLoss
@@ -99,7 +98,7 @@
-
31class MNISTCapsuleNetworkModel(Module):
+
29class MNISTCapsuleNetworkModel(Module):
@@ -110,8 +109,8 @@
-
36    def __init__(self):
-37        super().__init__()
+
34    def __init__(self):
+35        super().__init__()
@@ -123,7 +122,7 @@
-
39        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
+
37        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
@@ -135,8 +134,8 @@
-
45        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
-46        self.squash = Squash()
+
43        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
+44        self.squash = Squash()
@@ -148,7 +147,7 @@
-
52        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
+
50        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
@@ -160,14 +159,14 @@
-
57        self.decoder = nn.Sequential(
-58            nn.Linear(16 * 10, 512),
+            
55        self.decoder = nn.Sequential(
+56            nn.Linear(16 * 10, 512),
+57            nn.ReLU(),
+58            nn.Linear(512, 1024),
 59            nn.ReLU(),
-60            nn.Linear(512, 1024),
-61            nn.ReLU(),
-62            nn.Linear(1024, 784),
-63            nn.Sigmoid()
-64        )
+60            nn.Linear(1024, 784),
+61            nn.Sigmoid()
+62        )
@@ -181,7 +180,7 @@
-
66    def forward(self, data: torch.Tensor):
+
64    def forward(self, data: torch.Tensor):
@@ -194,7 +193,7 @@
-
72        x = F.relu(self.conv1(data))
+
70        x = F.relu(self.conv1(data))
@@ -207,7 +206,7 @@
-
76        x = self.conv2(x)
+
74        x = self.conv2(x)
@@ -219,7 +218,7 @@
-
79        caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
+
77        caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
@@ -231,7 +230,7 @@
-
81        caps = self.squash(caps)
+
79        caps = self.squash(caps)
@@ -244,7 +243,7 @@
-
84        caps = self.digit_capsules(caps)
+
82        caps = self.digit_capsules(caps)
@@ -256,7 +255,7 @@
-
87        with torch.no_grad():
+
85        with torch.no_grad():
@@ -268,7 +267,7 @@
-
89            pred = (caps ** 2).sum(-1).argmax(-1)
+
87            pred = (caps ** 2).sum(-1).argmax(-1)
@@ -280,7 +279,7 @@
-
91            mask = torch.eye(10, device=data.device)[pred]
+
89            mask = torch.eye(10, device=data.device)[pred]
@@ -292,7 +291,7 @@
-
95        reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))
+
93        reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))
@@ -304,9 +303,9 @@
-
97        reconstructions = reconstructions.view(-1, 1, 28, 28)
-98
-99        return caps, reconstructions, pred
+
95        reconstructions = reconstructions.view(-1, 1, 28, 28)
+96
+97        return caps, reconstructions, pred
@@ -318,7 +317,7 @@
-
102class Configs(MNISTConfigs, SimpleTrainValidConfigs):
+
100class Configs(MNISTConfigs, SimpleTrainValidConfigs):
@@ -329,11 +328,11 @@
-
106    epochs: int = 10
-107    model: nn.Module = 'capsule_network_model'
-108    reconstruction_loss = nn.MSELoss()
-109    margin_loss = MarginLoss(n_labels=10)
-110    accuracy = AccuracyDirect()
+
104    epochs: int = 10
+105    model: nn.Module = 'capsule_network_model'
+106    reconstruction_loss = nn.MSELoss()
+107    margin_loss = MarginLoss(n_labels=10)
+108    accuracy = AccuracyDirect()
@@ -344,7 +343,7 @@
-
112    def init(self):
+
110    def init(self):
@@ -356,8 +355,8 @@
-
114        tracker.set_scalar('loss.*', True)
-115        tracker.set_scalar('accuracy.*', True)
+
112        tracker.set_scalar('loss.*', True)
+113        tracker.set_scalar('accuracy.*', True)
@@ -369,7 +368,7 @@
-
118        self.state_modules = [self.accuracy]
+
116        self.state_modules = [self.accuracy]
@@ -381,7 +380,7 @@
-
120    def step(self, batch: Any, batch_idx: BatchIndex):
+
118    def step(self, batch: Any, batch_idx: BatchIndex):
@@ -393,7 +392,7 @@
-
125        self.model.train(self.mode.is_train)
+
123        self.model.train(self.mode.is_train)
@@ -405,7 +404,7 @@
-
128        data, target = batch[0].to(self.device), batch[1].to(self.device)
+
126        data, target = batch[0].to(self.device), batch[1].to(self.device)
@@ -417,8 +416,8 @@
-
131        if self.mode.is_train:
-132            tracker.add_global_step(len(data))
+
129        if self.mode.is_train:
+130            tracker.add_global_step(len(data))
@@ -430,7 +429,7 @@
-
135        with self.mode.update(is_log_activations=batch_idx.is_last):
+
133        with self.mode.update(is_log_activations=batch_idx.is_last):
@@ -442,7 +441,7 @@
-
137            caps, reconstructions, pred = self.model(data)
+
135            caps, reconstructions, pred = self.model(data)
@@ -454,8 +453,8 @@
-
140        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
-141        tracker.add("loss.", loss)
+
138        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
+139        tracker.add("loss.", loss)
@@ -467,12 +466,12 @@
-
144        self.accuracy(pred, target)
-145
-146        if self.mode.is_train:
-147            loss.backward()
-148
-149            self.optimizer.step()
+
142        self.accuracy(pred, target)
+143
+144        if self.mode.is_train:
+145            loss.backward()
+146
+147            self.optimizer.step()
@@ -484,11 +483,11 @@
-
151            if batch_idx.is_last:
-152                tracker.add('model', self.model)
-153            self.optimizer.zero_grad()
-154
-155            tracker.save()
+
149            if batch_idx.is_last:
+150                tracker.add('model', self.model)
+151            self.optimizer.zero_grad()
+152
+153            tracker.save()
@@ -500,8 +499,8 @@
-
158@option(Configs.model)
-159def capsule_network_model(c: Configs):
+
156@option(Configs.model)
+157def capsule_network_model(c: Configs):
@@ -512,7 +511,7 @@
-
161    return MNISTCapsuleNetworkModel().to(c.device)
+
159    return MNISTCapsuleNetworkModel().to(c.device)
@@ -524,7 +523,7 @@
-
164def main():
+
162def main():
@@ -535,19 +534,19 @@
-
168    experiment.create(name='capsule_network_mnist')
-169    conf = Configs()
-170    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
-171                              'optimizer.learning_rate': 1e-3})
+            
166    experiment.create(name='capsule_network_mnist')
+167    conf = Configs()
+168    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
+169                              'optimizer.learning_rate': 1e-3})
+170
+171    experiment.add_pytorch_models({'model': conf.model})
 172
-173    experiment.add_pytorch_models({'model': conf.model})
-174
-175    with experiment.start():
-176        conf.run()
-177
-178
-179if __name__ == '__main__':
-180    main()
+173    with experiment.start():
+174        conf.run()
+175
+176
+177if __name__ == '__main__':
+178    main()
109        tracker.set_scalar("accuracy.*", True)
-110        tracker.set_scalar("loss.*", True)
+110        tracker.set_scalar("loss.*", True)
+111        tracker.set_text("sampled", False)
@@ -414,7 +415,7 @@
-
112        hook_model_outputs(self.mode, self.model, 'model')
+
113        hook_model_outputs(self.mode, self.model, 'model')
@@ -426,7 +427,7 @@
-
117        self.state_modules = [self.accuracy]
+
118        self.state_modules = [self.accuracy]
@@ -438,7 +439,7 @@
-
119    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
+
120    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
@@ -449,7 +450,7 @@
-
121        pass
+
122        pass
@@ -461,7 +462,7 @@
-
123    def step(self, batch: any, batch_idx: BatchIndex):
+
124    def step(self, batch: any, batch_idx: BatchIndex):
@@ -473,7 +474,7 @@
-
129        self.model.train(self.mode.is_train)
+
130        self.model.train(self.mode.is_train)
@@ -485,7 +486,7 @@
-
132        data, target = batch[0].to(self.device), batch[1].to(self.device)
+
133        data, target = batch[0].to(self.device), batch[1].to(self.device)
@@ -497,8 +498,8 @@
-
135        if self.mode.is_train:
-136            tracker.add_global_step(data.shape[0] * data.shape[1])
+
136        if self.mode.is_train:
+137            tracker.add_global_step(data.shape[0] * data.shape[1])
@@ -510,7 +511,7 @@
-
139        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
+
140        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
@@ -522,7 +523,7 @@
-
143            output, *_ = self.model(data)
+
144            output, *_ = self.model(data)
@@ -534,8 +535,8 @@
-
146        loss = self.loss_func(output, target)
-147        tracker.add("loss.", loss)
+
147        loss = self.loss_func(output, target)
+148        tracker.add("loss.", loss)
@@ -547,10 +548,10 @@
-
150        self.accuracy(output, target)
-151        self.accuracy.track()
-152
-153        self.other_metrics(output, target)
+
151        self.accuracy(output, target)
+152        self.accuracy.track()
+153
+154        self.other_metrics(output, target)
@@ -562,7 +563,7 @@
-
156        if self.mode.is_train:
+
157        if self.mode.is_train:
@@ -574,7 +575,7 @@
-
158            loss.backward()
+
159            loss.backward()
@@ -586,7 +587,7 @@
-
160            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
+
161            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
@@ -598,7 +599,7 @@
-
162            self.optimizer.step()
+
163            self.optimizer.step()
@@ -610,8 +611,8 @@
-
164            if batch_idx.is_last and self.is_log_model_params_grads:
-165                tracker.add('model', self.model)
+
165            if batch_idx.is_last and self.is_log_model_params_grads:
+166                tracker.add('model', self.model)
@@ -623,7 +624,7 @@
-
167            self.optimizer.zero_grad()
+
168            self.optimizer.zero_grad()
@@ -635,7 +636,7 @@
-
170        tracker.save()
+
171        tracker.save()
@@ -647,7 +648,7 @@
-
172    def sample(self):
+
173    def sample(self):
@@ -659,7 +660,7 @@
-
178        prompt = self.prompt
+
179        prompt = self.prompt
@@ -671,7 +672,7 @@
-
180        log = [(prompt, Text.subtle)]
+
181        log = [(prompt, Text.subtle)]
@@ -683,7 +684,7 @@
-
182        for i in monit.iterate('Sample', 25):
+
183        for i in monit.iterate('Sample', 25):
@@ -695,8 +696,8 @@
-
184            data = self.text.text_to_i(prompt).unsqueeze(-1)
-185            data = data.to(self.device)
+
185            data = self.text.text_to_i(prompt).unsqueeze(-1)
+186            data = data.to(self.device)
@@ -708,7 +709,7 @@
-
187            output, *_ = self.model(data)
+
188            output, *_ = self.model(data)
@@ -720,7 +721,7 @@
-
189            output = output.argmax(dim=-1).squeeze()
+
190            output = output.argmax(dim=-1).squeeze()
@@ -732,7 +733,7 @@
-
191            prompt += self.prompt_separator + self.text.itos[output[-1]]
+
192            prompt += self.prompt_separator + self.text.itos[output[-1]]
@@ -744,7 +745,9 @@
-
193            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
+
194            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
+195
+196        tracker.add({'sampled': prompt})
@@ -756,7 +759,7 @@
-
196        logger.log(log)
+
198        logger.log(log)
@@ -768,8 +771,8 @@
-
199@option(NLPAutoRegressionConfigs.optimizer)
-200def _optimizer(c: NLPAutoRegressionConfigs):
+
201@option(NLPAutoRegressionConfigs.optimizer)
+202def _optimizer(c: NLPAutoRegressionConfigs):
@@ -780,12 +783,12 @@
-
205    optimizer = OptimizerConfigs()
-206    optimizer.parameters = c.model.parameters()
-207    optimizer.optimizer = 'Adam'
-208    optimizer.d_model = c.d_model
-209
-210    return optimizer
+
207    optimizer = OptimizerConfigs()
+208    optimizer.parameters = c.model.parameters()
+209    optimizer.optimizer = 'Adam'
+210    optimizer.d_model = c.d_model
+211
+212    return optimizer
@@ -797,8 +800,8 @@
-
213@option(NLPAutoRegressionConfigs.n_tokens)
-214def _n_tokens(c: NLPAutoRegressionConfigs):
+
215@option(NLPAutoRegressionConfigs.n_tokens)
+216def _n_tokens(c: NLPAutoRegressionConfigs):
@@ -809,7 +812,7 @@
-
218    return c.text.n_tokens
+
220    return c.text.n_tokens
@@ -824,8 +827,8 @@
-
221@option(NLPAutoRegressionConfigs.tokenizer)
-222def basic_english():
+
223@option(NLPAutoRegressionConfigs.tokenizer)
+224def basic_english():
@@ -836,8 +839,8 @@
-
236    from torchtext.data import get_tokenizer
-237    return get_tokenizer('basic_english')
+
238    from torchtext.data import get_tokenizer
+239    return get_tokenizer('basic_english')
@@ -849,7 +852,7 @@
-
240def character_tokenizer(x: str):
+
242def character_tokenizer(x: str):
@@ -860,7 +863,7 @@
-
244    return list(x)
+
246    return list(x)
@@ -872,8 +875,8 @@
-
247@option(NLPAutoRegressionConfigs.tokenizer)
-248def character():
+
249@option(NLPAutoRegressionConfigs.tokenizer)
+250def character():
@@ -884,7 +887,7 @@
-
252    return character_tokenizer
+
254    return character_tokenizer
@@ -897,8 +900,8 @@
-
255@option(NLPAutoRegressionConfigs.text)
-256def tiny_shakespeare(c: NLPAutoRegressionConfigs):
+
257@option(NLPAutoRegressionConfigs.text)
+258def tiny_shakespeare(c: NLPAutoRegressionConfigs):
@@ -909,10 +912,10 @@
-
262    return TextFileDataset(
-263        lab.get_data_path() / 'tiny_shakespeare.txt',
-264        c.tokenizer,
-265        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
+
264    return TextFileDataset(
+265        lab.get_data_path() / 'tiny_shakespeare.txt',
+266        c.tokenizer,
+267        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
@@ -924,8 +927,8 @@
-
268@option(NLPAutoRegressionConfigs.train_loader)
-269def sequential_train_loader(c: NLPAutoRegressionConfigs):
+
270@option(NLPAutoRegressionConfigs.train_loader)
+271def sequential_train_loader(c: NLPAutoRegressionConfigs):
@@ -936,10 +939,10 @@
-
273    return SequentialDataLoader(text=c.text.train,
-274                                dataset=c.text,
-275                                batch_size=c.batch_size,
-276                                seq_len=c.seq_len)
+
275    return SequentialDataLoader(text=c.text.train,
+276                                dataset=c.text,
+277                                batch_size=c.batch_size,
+278                                seq_len=c.seq_len)
@@ -951,8 +954,8 @@
-
279@option(NLPAutoRegressionConfigs.valid_loader)
-280def sequential_valid_loader(c: NLPAutoRegressionConfigs):
+
281@option(NLPAutoRegressionConfigs.valid_loader)
+282def sequential_valid_loader(c: NLPAutoRegressionConfigs):
@@ -963,10 +966,10 @@
-
284    return SequentialDataLoader(text=c.text.valid,
-285                                dataset=c.text,
-286                                batch_size=c.batch_size,
-287                                seq_len=c.seq_len)
+
286    return SequentialDataLoader(text=c.text.valid,
+287                                dataset=c.text,
+288                                batch_size=c.batch_size,
+289                                seq_len=c.seq_len)
@@ -980,7 +983,7 @@
-
290def transpose_batch(batch):
+
292def transpose_batch(batch):
@@ -991,7 +994,7 @@
-
298    transposed_data = list(zip(*batch))
+
300    transposed_data = list(zip(*batch))
@@ -1004,10 +1007,10 @@
-
300    src = torch.stack(transposed_data[0], dim=1)
-301    tgt = torch.stack(transposed_data[1], dim=1)
-302
-303    return src, tgt
+
302    src = torch.stack(transposed_data[0], dim=1)
+303    tgt = torch.stack(transposed_data[1], dim=1)
+304
+305    return src, tgt
@@ -1019,8 +1022,8 @@
-
306@option(NLPAutoRegressionConfigs.train_loader)
-307def shuffled_train_loader(c: NLPAutoRegressionConfigs):
+
308@option(NLPAutoRegressionConfigs.train_loader)
+309def shuffled_train_loader(c: NLPAutoRegressionConfigs):
@@ -1031,15 +1034,15 @@
-
311    dataset = SequentialUnBatchedDataset(text=c.text.train,
-312                                         dataset=c.text,
-313                                         seq_len=c.seq_len)
-314    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)
-315
-316    return DataLoader(dataset,
-317                      batch_size=c.batch_size,
-318                      collate_fn=transpose_batch,
-319                      sampler=sampler)
+
313    dataset = SequentialUnBatchedDataset(text=c.text.train,
+314                                         dataset=c.text,
+315                                         seq_len=c.seq_len)
+316    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)
+317
+318    return DataLoader(dataset,
+319                      batch_size=c.batch_size,
+320                      collate_fn=transpose_batch,
+321                      sampler=sampler)
@@ -1051,8 +1054,8 @@
-
322@option(NLPAutoRegressionConfigs.valid_loader)
-323def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
+
324@option(NLPAutoRegressionConfigs.valid_loader)
+325def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
@@ -1063,15 +1066,15 @@
-
327    dataset = SequentialUnBatchedDataset(text=c.text.valid,
-328                                         dataset=c.text,
-329                                         seq_len=c.seq_len)
-330    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)
-331
-332    return DataLoader(dataset,
-333                      batch_size=c.batch_size,
-334                      collate_fn=transpose_batch,
-335                      sampler=sampler)
+
329    dataset = SequentialUnBatchedDataset(text=c.text.valid,
+330                                         dataset=c.text,
+331                                         seq_len=c.seq_len)
+332    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)
+333
+334    return DataLoader(dataset,
+335                      batch_size=c.batch_size,
+336                      collate_fn=transpose_batch,
+337                      sampler=sampler)

DeepNorm Experiment

-

Open In Colab View Run Open In Comet

+

Open In Colab Open In Comet

-
15import copy
-16
-17import torch
-18import torch.nn as nn
-19
-20from labml import experiment
-21from labml.configs import option
-22from labml_helpers.module import Module
-23from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
-24from labml_nn.normalization.deep_norm import DeepNormTransformerLayer
-25from labml_nn.transformers import MultiHeadAttention
-26from labml_nn.transformers.feed_forward import FeedForward
+
14import copy
+15
+16import torch
+17import torch.nn as nn
+18
+19from labml import experiment
+20from labml.configs import option
+21from labml_helpers.module import Module
+22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+23from labml_nn.normalization.deep_norm import DeepNormTransformerLayer
+24from labml_nn.transformers import MultiHeadAttention
+25from labml_nn.transformers.feed_forward import FeedForward
@@ -98,7 +98,7 @@
-
29class AutoregressiveTransformer(Module):
+
28class AutoregressiveTransformer(Module):
@@ -114,7 +114,7 @@
-
36    def __init__(self, n_tokens: int, d_model: int, n_layers: int, layer: DeepNormTransformerLayer):
+
35    def __init__(self, n_tokens: int, d_model: int, n_layers: int, layer: DeepNormTransformerLayer):
@@ -125,7 +125,7 @@
-
43        super().__init__()
+
42        super().__init__()
@@ -138,7 +138,7 @@
-
45        self.transformer = nn.Sequential(*[copy.deepcopy(layer) for _ in range(n_layers)])
+
44        self.transformer = nn.Sequential(*[copy.deepcopy(layer) for _ in range(n_layers)])
@@ -150,7 +150,7 @@
-
48        self.emb = nn.Embedding(n_tokens, d_model)
+
47        self.emb = nn.Embedding(n_tokens, d_model)
@@ -162,7 +162,7 @@
-
50        self.readout = nn.Linear(d_model, n_tokens)
+
49        self.readout = nn.Linear(d_model, n_tokens)
@@ -175,7 +175,7 @@
-
52    def forward(self, x: torch.Tensor):
+
51    def forward(self, x: torch.Tensor):
@@ -187,7 +187,7 @@
-
57        x = self.emb(x)
+
56        x = self.emb(x)
@@ -199,7 +199,7 @@
-
59        x = self.transformer(x)
+
58        x = self.transformer(x)
@@ -211,7 +211,7 @@
-
61        x = self.readout(x)
+
60        x = self.readout(x)
@@ -223,7 +223,7 @@
-
64        return x, None
+
63        return x, None
@@ -237,7 +237,7 @@
-
67class Configs(NLPAutoRegressionConfigs):
+
66class Configs(NLPAutoRegressionConfigs):
@@ -249,7 +249,7 @@
-
76    model: AutoregressiveTransformer
+
75    model: AutoregressiveTransformer
@@ -261,7 +261,7 @@
-
79    n_layers: int = 64
+
78    n_layers: int = 64
@@ -273,8 +273,8 @@
-
82    deep_norm_alpha: float
-83    deep_norm_beta: float
+
81    deep_norm_alpha: float
+82    deep_norm_beta: float
@@ -286,7 +286,7 @@
-
86    n_heads: int = 4
+
85    n_heads: int = 4
@@ -298,7 +298,7 @@
-
88    d_model: int = 64
+
87    d_model: int = 64
@@ -310,7 +310,7 @@
-
90    d_k: int = 16
+
89    d_k: int = 16
@@ -323,8 +323,8 @@
-
93@option(Configs.deep_norm_alpha)
-94def _deep_norm_alpha(c: Configs):
+
92@option(Configs.deep_norm_alpha)
+93def _deep_norm_alpha(c: Configs):
@@ -335,7 +335,7 @@
-
100    return (2. * c.n_layers) ** (1. / 4.)
+
99    return (2. * c.n_layers) ** (1. / 4.)
@@ -348,8 +348,8 @@
-
103@option(Configs.deep_norm_beta)
-104def _deep_norm_beta(c: Configs):
+
102@option(Configs.deep_norm_beta)
+103def _deep_norm_beta(c: Configs):
@@ -360,7 +360,7 @@
-
110    return (8. * c.n_layers) ** -(1. / 4.)
+
109    return (8. * c.n_layers) ** -(1. / 4.)
@@ -372,8 +372,8 @@
-
113@option(Configs.model)
-114def _model(c: Configs):
+
112@option(Configs.model)
+113def _model(c: Configs):
@@ -384,16 +384,16 @@
-
118    m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.n_layers,
-119                                  DeepNormTransformerLayer(d_model=c.d_model,
-120                                                           deep_norm_alpha=c.deep_norm_alpha,
-121                                                           deep_norm_beta=c.deep_norm_beta,
-122                                                           feed_forward=FeedForward(d_model=c.d_model,
-123                                                                                    d_ff=c.d_model * 4),
-124                                                           self_attn=MultiHeadAttention(c.n_heads, c.d_model,
-125                                                                                        dropout_prob=0.0)))
-126
-127    return m.to(c.device)
+
117    m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.n_layers,
+118                                  DeepNormTransformerLayer(d_model=c.d_model,
+119                                                           deep_norm_alpha=c.deep_norm_alpha,
+120                                                           deep_norm_beta=c.deep_norm_beta,
+121                                                           feed_forward=FeedForward(d_model=c.d_model,
+122                                                                                    d_ff=c.d_model * 4),
+123                                                           self_attn=MultiHeadAttention(c.n_heads, c.d_model,
+124                                                                                        dropout_prob=0.0)))
+125
+126    return m.to(c.device)
@@ -405,7 +405,7 @@
-
130def main():
+
129def main():
@@ -417,7 +417,7 @@
-
135    experiment.create(name="deep_norm", writers={'screen', 'web_api', 'comet'})
+
134    experiment.create(name="deep_norm", writers={'screen', 'web_api', 'comet'})
@@ -429,7 +429,7 @@
-
137    conf = Configs()
+
136    conf = Configs()
@@ -441,7 +441,7 @@
-
139    experiment.configs(conf, {
+
138    experiment.configs(conf, {
@@ -453,7 +453,7 @@
-
141        'tokenizer': 'character',
+
140        'tokenizer': 'character',
@@ -465,7 +465,7 @@
-
143        'prompt_separator': '',
+
142        'prompt_separator': '',
@@ -477,7 +477,7 @@
-
145        'prompt': 'It is ',
+
144        'prompt': 'It is ',
@@ -489,7 +489,7 @@
-
147        'text': 'tiny_shakespeare',
+
146        'text': 'tiny_shakespeare',
@@ -501,7 +501,7 @@
-
150        'seq_len': 256,
+
149        'seq_len': 256,
@@ -513,7 +513,7 @@
-
152        'epochs': 32,
+
151        'epochs': 32,
@@ -525,7 +525,7 @@
-
154        'batch_size': 16,
+
153        'batch_size': 16,
@@ -537,7 +537,7 @@
-
156        'inner_iterations': 10,
+
155        'inner_iterations': 10,
@@ -549,9 +549,9 @@
-
159        'optimizer.optimizer': 'Adam',
-160        'optimizer.learning_rate': 3e-4,
-161    })
+
158        'optimizer.optimizer': 'Adam',
+159        'optimizer.learning_rate': 3e-4,
+160    })
@@ -563,7 +563,7 @@
-
164    experiment.add_pytorch_models({'model': conf.model})
+
163    experiment.add_pytorch_models({'model': conf.model})
@@ -575,7 +575,7 @@
-
167    with experiment.start():
+
166    with experiment.start():
@@ -587,7 +587,7 @@
-
169        conf.run()
+
168        conf.run()
@@ -599,8 +599,8 @@
-
173if __name__ == '__main__':
-174    main()
+
172if __name__ == '__main__':
+173    main()

DeepNorm

+

Open In Colab Open In Comet

This is a PyTorch implementation of DeepNorm from the paper DeepNet: Scaling Transformers to 1,000 Layers.

The paper proposes a method to stabilize extremely deep transformers through a new normalizing function to replace LayerNorm and a weight initialization scheme. This combines the performance of Post-LayerNorm and the stability of Pre-LayerNorm. Transformers with DeepNorm are supposed to be stable even without a learning rate warm-up.

The paper first shows that the change in layer outputs (for the same input) is gradual during stable training, whereas it changes rapidly during the initial training steps when training is unstable. Small weight initializations and learning rate warm-ups, which make training stable, also keep this change small. They use the idea of keeping the changes to layer outputs small to derive the new normalization and weight initialization mechanism.
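As a sketch of the resulting update (not the library's exact code): the residual branch is up-weighted by a constant alpha before layer normalization, and the weights of the feed-forward layers and of the attention value and output projections are scaled down by a constant beta at initialization:

```python
import torch
import torch.nn as nn


def deep_norm(x: torch.Tensor, gx: torch.Tensor, alpha: float, layer_norm: nn.LayerNorm):
    # x_{l+1} = LN(alpha * x_l + G_l(x_l, theta_l))
    return layer_norm(alpha * x + gx)


def scale_at_init(linear: nn.Linear, beta: float):
    # Applied to the feed-forward layers and the attention value/output projections
    with torch.no_grad():
        linear.weight *= beta
```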

@@ -85,19 +86,18 @@

Where N is the number of layers in the encoder and M is the number of layers in the decoder.

Refer to the paper for derivation.
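For the decoder-only language model in the experiment referenced below, the configuration code in this diff derives both constants from the number of layers N; restated here as plain functions:

```python
def deep_norm_alpha(n_layers: int) -> float:
    # alpha = (2 N) ** (1 / 4)
    return (2. * n_layers) ** (1. / 4.)


def deep_norm_beta(n_layers: int) -> float:
    # beta = (8 N) ** (-1 / 4)
    return (8. * n_layers) ** -(1. / 4.)
```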

Here is an experiment implementation that uses DeepNorm.

-

Open In Colab View Run Open In Comet

-
75from typing import Union, List
-76
-77import torch
-78from torch import nn, Size
-79
-80from labml_nn.normalization.layer_norm import LayerNorm
-81from labml_nn.transformers import MultiHeadAttention
-82from labml_nn.transformers.feed_forward import FeedForward
-83from labml_nn.transformers.utils import subsequent_mask
+
74from typing import Union, List
+75
+76import torch
+77from torch import nn, Size
+78
+79from labml_nn.normalization.layer_norm import LayerNorm
+80from labml_nn.transformers import MultiHeadAttention
+81from labml_nn.transformers.feed_forward import FeedForward
+82from labml_nn.transformers.utils import subsequent_mask
@@ -110,7 +110,7 @@
-
86class DeepNorm(nn.Module):
+
85class DeepNorm(nn.Module):
@@ -125,9 +125,9 @@
-
93    def __init__(self, alpha: float, normalized_shape: Union[int, List[int], Size], *,
-94                 eps: float = 1e-5,
-95                 elementwise_affine: bool = True):
+
92    def __init__(self, alpha: float, normalized_shape: Union[int, List[int], Size], *,
+93                 eps: float = 1e-5,
+94                 elementwise_affine: bool = True):
@@ -138,9 +138,9 @@
-
102        super().__init__()
-103
-104        self.alpha = alpha
+
101        super().__init__()
+102
+103        self.alpha = alpha
@@ -152,7 +152,7 @@
-
106        self.layer_norm = LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
+
105        self.layer_norm = LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
@@ -165,7 +165,7 @@
-
108    def forward(self, x: torch.Tensor, gx: torch.Tensor):
+
107    def forward(self, x: torch.Tensor, gx: torch.Tensor):
@@ -177,7 +177,7 @@
-
114        return x + self.alpha * gx
+
113        return x + self.alpha * gx
@@ -190,7 +190,7 @@
-
117class DeepNormTransformerLayer(nn.Module):
+
116class DeepNormTransformerLayer(nn.Module):
@@ -206,13 +206,13 @@
-
124    def __init__(self, *,
-125                 d_model: int,
-126                 self_attn: MultiHeadAttention,
-127                 feed_forward: FeedForward,
-128                 deep_norm_alpha: float,
-129                 deep_norm_beta: float,
-130                 ):
+
123    def __init__(self, *,
+124                 d_model: int,
+125                 self_attn: MultiHeadAttention,
+126                 feed_forward: FeedForward,
+127                 deep_norm_alpha: float,
+128                 deep_norm_beta: float,
+129                 ):
@@ -223,10 +223,10 @@
-
138        super().__init__()
-139
-140        self.self_attn = self_attn
-141        self.feed_forward = feed_forward
+
137        super().__init__()
+138
+139        self.self_attn = self_attn
+140        self.feed_forward = feed_forward
@@ -238,8 +238,8 @@
-
143        self.self_attn_norm = DeepNorm(deep_norm_alpha, [d_model])
-144        self.feed_forward_norm = DeepNorm(deep_norm_alpha, [d_model])
+
142        self.self_attn_norm = DeepNorm(deep_norm_alpha, [d_model])
+143        self.feed_forward_norm = DeepNorm(deep_norm_alpha, [d_model])
@@ -251,7 +251,7 @@
-
147        with torch.no_grad():
+
146        with torch.no_grad():
@@ -263,8 +263,8 @@
-
149            feed_forward.layer1.weight *= deep_norm_beta
-150            feed_forward.layer2.weight *= deep_norm_beta
+
148            feed_forward.layer1.weight *= deep_norm_beta
+149            feed_forward.layer2.weight *= deep_norm_beta
@@ -276,7 +276,7 @@
-
153            self_attn.value.linear.weight *= deep_norm_beta
+
152            self_attn.value.linear.weight *= deep_norm_beta
@@ -288,7 +288,7 @@
-
155            self_attn.output.weight *= deep_norm_beta
+
154            self_attn.output.weight *= deep_norm_beta
@@ -300,7 +300,7 @@
-
158        self.mask = None
+
157        self.mask = None
@@ -313,7 +313,7 @@
-
160    def forward(self, x: torch.Tensor):
+
159    def forward(self, x: torch.Tensor):
@@ -325,7 +325,7 @@
-
165        if self.mask is None or self.mask.size(0) != len(x):
+
164        if self.mask is None or self.mask.size(0) != len(x):
@@ -337,7 +337,7 @@
-
167            self.mask = subsequent_mask(len(x)).to(x.device)
+
166            self.mask = subsequent_mask(len(x)).to(x.device)
@@ -349,7 +349,7 @@
-
170        x = self.self_attn_norm(x, self.self_attn(query=x, key=x, value=x, mask=self.mask))
+
169        x = self.self_attn_norm(x, self.self_attn(query=x, key=x, value=x, mask=self.mask))
@@ -361,7 +361,7 @@
-
172        x = self.feed_forward_norm(x, self.feed_forward(x))
+
171        x = self.feed_forward_norm(x, self.feed_forward(x))
@@ -373,7 +373,7 @@
-
175        return x
+
174        return x