Here we train a transformer that uses Fuzzy Tiling Activation (FTA) in the feed-forward network. We use it for a language model and train it on the Tiny Shakespeare dataset for demonstration.
However, this is probably not the ideal task for FTA, and we believe FTA is more suitable for modeling data with continuous variables.
This is a PyTorch implementation/tutorial of Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online.
Fuzzy tiling activations are a form of sparse activations based on binning.
Binning is the classification of a scalar value into a bin based on intervals. One problem with binning is that it gives zero gradients for most values (except at the boundaries of bins). The other is that binning loses precision if the bin intervals are large.
FTA replaces the hard binning indicator with a fuzzy indicator function to create soft boundaries between bins, so that values near a bin edge get non-zero gradients.
Here's a simple experiment that uses FTA in a transformer.
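Before the experiment, here is a minimal, self-contained sketch of the activation itself (it is not the repository's module, and the default values for the tiling range, bin width delta, and sparsity parameter eta are only illustrative): each scalar z is compared against the bin boundaries c, and the hard indicator is replaced by a fuzzy one controlled by eta.

import torch

def fta(z: torch.Tensor, lower: float = -1., upper: float = 1.,
        delta: float = 0.25, eta: float = 0.01) -> torch.Tensor:
    """Expand each scalar activation into a soft-binned vector."""
    # Bin lower boundaries (the tiling): [lower, lower + delta, ..., upper - delta]
    c = torch.arange(lower, upper, delta, device=z.device)
    z = z.unsqueeze(-1)
    # Distance of z from the bin [c, c + delta]; zero when z falls inside the bin
    d = torch.clip(c - z, min=0.) + torch.clip(z - delta - c, min=0.)
    # Fuzzy indicator: linear ramp for d <= eta, saturates at 1 beyond it,
    # which gives non-zero gradients near bin boundaries
    i_plus = torch.where(d <= eta, d, torch.ones_like(d))
    # 1 - indicator is the (soft) bin membership; flatten bins into the feature dimension
    return (1. - i_plus).flatten(-2)

x = torch.randn(2, 4)
print(fta(x).shape)   # torch.Size([2, 32]) -- each unit expands into 8 bins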
This file holds the implementations of the core modules of Capsule Networks.
I used jindongwang/Pytorch-CapsuleNet to clarify some confusions I had with the paper.
Here's a notebook for training a Capsule Network on the MNIST dataset.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from labml_helpers.module import Module


class Squash(Module):
    # Squash non-linearity: shrinks short vectors towards zero and long vectors towards unit length.
    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, s: torch.Tensor):
        # Squared norm of each capsule
        s2 = (s ** 2).sum(dim=-1, keepdim=True)
        # epsilon avoids division by zero when the norm is zero
        return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))


class Router(Module):
    # Routing by agreement between lower-level and higher-level capsules.
    def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
        super().__init__()
        self.in_caps = in_caps
        self.out_caps = out_caps
        self.iterations = iterations
        self.softmax = nn.Softmax(dim=1)
        self.squash = Squash()

        # Transformation from lower-capsule space to higher-capsule space
        self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)

    def forward(self, u: torch.Tensor):
        # Prediction vectors u_hat[b, i, j, :] from each input capsule i for each output capsule j
        u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)
        # Initial routing logits
        b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps)

        v = None
        for i in range(self.iterations):
            # Coupling coefficients
            c = self.softmax(b)
            # Weighted sum of prediction vectors
            s = torch.einsum('bij,bijm->bjm', c, u_hat)
            v = self.squash(s)
            # Agreement between predictions and outputs
            a = torch.einsum('bjm,bijm->bij', v, u_hat)
            b = b + a

        return v


class MarginLoss(Module):
    # Margin loss on capsule lengths for multi-class classification.
    def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
        super().__init__()

        self.m_negative = m_negative
        self.m_positive = m_positive
        self.lambda_ = lambda_
        self.n_labels = n_labels

    def forward(self, v: torch.Tensor, labels: torch.Tensor):
        # Lengths of the output capsules
        v_norm = torch.sqrt((v ** 2).sum(dim=-1))
        # One-hot encode the labels
        labels = torch.eye(self.n_labels, device=labels.device)[labels]
        loss = labels * F.relu(self.m_positive - v_norm) + \
               self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative)
        return loss.sum(dim=-1).mean()
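As a quick shape check (not part of the original file), the router maps a batch of lower-level capsules to higher-level capsules; the sizes below match the MNIST model that follows.

import torch
from labml_nn.capsule_networks import Router, MarginLoss

# 32 * 6 * 6 = 1152 primary capsules of 8 dimensions each, routed to
# 10 digit capsules of 16 dimensions with 3 routing iterations
router = Router(in_caps=32 * 6 * 6, out_caps=10, in_d=8, out_d=16, iterations=3)
u = torch.randn(4, 32 * 6 * 6, 8)    # [batch_size, in_caps, in_d]
v = router(u)                        # [batch_size, out_caps, out_d] -> [4, 10, 16]

# The margin loss takes the output capsules and integer class labels
loss = MarginLoss(n_labels=10)(v, torch.randint(0, 10, (4,)))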
# MNIST experiment: model, training configuration, and training loop for the capsule network.
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.metrics.accuracy import AccuracyDirect
from labml_helpers.module import Module
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
from labml_nn.capsule_networks import Squash, Router, MarginLoss


class MNISTCapsuleNetworkModel(Module):
    def __init__(self):
        super().__init__()
        # First convolution: 28x28x1 -> 20x20x256
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
        # Second convolution produces the primary capsules: 32 capsule maps of 8 dimensions each
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
        self.squash = Squash()
        # Route the 32 * 6 * 6 primary capsules to 10 digit capsules of 16 dimensions
        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
        # Decoder that reconstructs the image from the digit capsules
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, data: torch.Tensor):
        x = F.relu(self.conv1(data))
        x = self.conv2(x)
        # Reshape to [batch_size, 32 * 6 * 6, 8] capsules
        caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
        caps = self.squash(caps)
        caps = self.digit_capsules(caps)

        with torch.no_grad():
            # Predicted class is the digit capsule with the largest length
            pred = (caps ** 2).sum(-1).argmax(-1)
            mask = torch.eye(10, device=data.device)[pred]

        # Reconstruct the image from the predicted capsule only
        reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))
        reconstructions = reconstructions.view(-1, 1, 28, 28)

        return caps, reconstructions, pred


class Configs(MNISTConfigs, SimpleTrainValidConfigs):
    epochs: int = 10
    # The model is created by the 'capsule_network_model' option below
    model: nn.Module = 'capsule_network_model'
    reconstruction_loss = nn.MSELoss()
    margin_loss = MarginLoss(n_labels=10)
    accuracy = AccuracyDirect()

    def init(self):
        # Log loss and accuracy to the screen
        tracker.set_scalar('loss.*', True)
        tracker.set_scalar('accuracy.*', True)
        # Reset the accuracy metric at the start of each epoch
        self.state_modules = [self.accuracy]

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.model.train(self.mode.is_train)
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(len(data))

        with self.mode.update(is_log_activations=batch_idx.is_last):
            caps, reconstructions, pred = self.model(data)

        # Margin loss plus a down-weighted reconstruction loss
        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
        tracker.add("loss.", loss)

        self.accuracy(pred, target)

        if self.mode.is_train:
            loss.backward()
            self.optimizer.step()
            # Log parameters and gradients on the last batch of the epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()


@option(Configs.model)
def capsule_network_model(c: Configs):
    return MNISTCapsuleNetworkModel().to(c.device)


def main():
    experiment.create(name='capsule_network_mnist')
    conf = Configs()
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1e-3})

    experiment.add_pytorch_models({'model': conf.model})

    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
# Excerpt from NLPAutoRegressionConfigs (labml_nn/experiments/nlp_autoregression.py):
# the training step, sampling loop, and configurable options. Imports are omitted in this excerpt.

    def init(self):
        # Log loss and accuracy to the screen
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Reset the accuracy metric at the start of each epoch
        self.state_modules = [self.accuracy]

    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        # Override this to track other metrics
        pass

    def step(self, batch: any, batch_idx: BatchIndex):
        self.model.train(self.mode.is_train)
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update the global step (number of tokens processed) when training
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
            output, *_ = self.model(data)

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)

        if self.mode.is_train:
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            self.optimizer.step()
            # Log parameters and gradients on the last batch of the epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()

    def sample(self):
        # Greedy sampling starting from the configured prompt
        prompt = self.prompt
        log = [(prompt, Text.subtle)]
        for i in monit.iterate('Sample', 25):
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            output, *_ = self.model(data)
            output = output.argmax(dim=-1).squeeze()
            # Append the predicted token to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

        tracker.add({'sampled': prompt})
        logger.log(log)


@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer


@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    return c.text.n_tokens


@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')


def character_tokenizer(x: str):
    # Character-level tokenizer
    return list(x)


@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    return character_tokenizer


@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    # Tiny Shakespeare dataset; downloaded on first use
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')


@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)


@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)


def transpose_batch(batch):
    # Stack samples so that the batch has shape [seq_len, batch_size]
    transposed_data = list(zip(*batch))
    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt


@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.train,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)


@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.valid,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)
# DeepNorm experiment: a decoder-only transformer built from DeepNormTransformerLayer, trained on Tiny Shakespeare.
import copy

import torch
import torch.nn as nn

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.normalization.deep_norm import DeepNormTransformerLayer
from labml_nn.transformers import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward


class AutoregressiveTransformer(Module):
    def __init__(self, n_tokens: int, d_model: int, n_layers: int, layer: DeepNormTransformerLayer):
        super().__init__()
        # Stack of identical (deep-copied) transformer layers
        self.transformer = nn.Sequential(*[copy.deepcopy(layer) for _ in range(n_layers)])
        self.emb = nn.Embedding(n_tokens, d_model)
        self.readout = nn.Linear(d_model, n_tokens)

    def forward(self, x: torch.Tensor):
        x = self.emb(x)
        x = self.transformer(x)
        x = self.readout(x)
        # Return logits (the second value keeps the interface of the other experiments)
        return x, None


class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveTransformer

    n_layers: int = 64
    # DeepNorm constants; computed from n_layers by the options below
    deep_norm_alpha: float
    deep_norm_beta: float

    n_heads: int = 4
    d_model: int = 64
    d_k: int = 16


@option(Configs.deep_norm_alpha)
def _deep_norm_alpha(c: Configs):
    return (2. * c.n_layers) ** (1. / 4.)


@option(Configs.deep_norm_beta)
def _deep_norm_beta(c: Configs):
    return (8. * c.n_layers) ** -(1. / 4.)


@option(Configs.model)
def _model(c: Configs):
    m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.n_layers,
                                  DeepNormTransformerLayer(d_model=c.d_model,
                                                           deep_norm_alpha=c.deep_norm_alpha,
                                                           deep_norm_beta=c.deep_norm_beta,
                                                           feed_forward=FeedForward(d_model=c.d_model,
                                                                                    d_ff=c.d_model * 4),
                                                           self_attn=MultiHeadAttention(c.n_heads, c.d_model,
                                                                                        dropout_prob=0.0)))

    return m.to(c.device)


def main():
    experiment.create(name="deep_norm", writers={'screen', 'web_api', 'comet'})
    conf = Configs()
    experiment.configs(conf, {
        # Character-level tokenizer on Tiny Shakespeare
        'tokenizer': 'character',
        'prompt_separator': '',
        'prompt': 'It is ',
        'text': 'tiny_shakespeare',

        'seq_len': 256,
        'epochs': 32,
        'batch_size': 16,
        'inner_iterations': 10,

        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 3e-4,
    })

    experiment.add_pytorch_models({'model': conf.model})

    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
This is a PyTorch implementation of DeepNorm from the paper DeepNet: Scaling Transformers to 1,000 Layers.
The paper proposes a method to stabilize extremely deep transformers with a new normalizing function, which replaces LayerNorm, together with a weight initialization scheme. The combination gives the performance of Post-LayerNorm and the stability of Pre-LayerNorm, and transformers with DeepNorm are supposed to be stable even without a learning rate warm-up.
The paper first shows that during stable training the outputs of a layer (for the same input) change gradually, whereas during unstable training they change rapidly in the initial training steps; training stays stable when weights are initialized to small values and a learning rate warm-up is used. They use this idea of keeping the changes to layer outputs small to derive the new normalization and weight initialization scheme.
DeepNorm replaces the residual connection x + G(x) with LN(α x + G(x)) and scales some of the residual-branch weights (the feed-forward layers and the attention value and output projections) by β at initialization, where the constants α and β depend on N, the number of layers in the encoder, and M, the number of layers in the decoder.
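For example, the decoder-only experiment above uses M = 64 layers, so the configuration options (_deep_norm_alpha and _deep_norm_beta) compute

$$\alpha = (2M)^{1/4} = 128^{1/4} \approx 3.36, \qquad \beta = (8M)^{-1/4} = 512^{-1/4} \approx 0.21.$$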
Refer to the paper for derivation.
Here is an experiment implementation that uses DeepNorm.
from typing import Union, List

import torch
from torch import nn, Size

from labml_nn.normalization.layer_norm import LayerNorm
from labml_nn.transformers import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.utils import subsequent_mask


class DeepNorm(nn.Module):
    # DeepNorm: LN(alpha * x + g(x)), replacing the usual post-norm residual connection.
    def __init__(self, alpha: float, normalized_shape: Union[int, List[int], Size], *,
                 eps: float = 1e-5,
                 elementwise_affine: bool = True):
        super().__init__()

        self.alpha = alpha
        self.layer_norm = LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x: torch.Tensor, gx: torch.Tensor):
        # x is the residual stream and gx is the output of the sub-layer g(x)
        return self.layer_norm(x + self.alpha * gx)


class DeepNormTransformerLayer(nn.Module):
    # A decoder layer with self-attention and a feed-forward network, both wrapped in DeepNorm.
    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 feed_forward: FeedForward,
                 deep_norm_alpha: float,
                 deep_norm_beta: float,
                 ):
        super().__init__()

        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.self_attn_norm = DeepNorm(deep_norm_alpha, [d_model])
        self.feed_forward_norm = DeepNorm(deep_norm_alpha, [d_model])

        # Scale the residual-branch weights by beta at initialization
        with torch.no_grad():
            feed_forward.layer1.weight *= deep_norm_beta
            feed_forward.layer2.weight *= deep_norm_beta

            self_attn.value.linear.weight *= deep_norm_beta
            self_attn.output.weight *= deep_norm_beta

        # Autoregressive mask, created lazily for the current sequence length
        self.mask = None

    def forward(self, x: torch.Tensor):
        if self.mask is None or self.mask.size(0) != len(x):
            self.mask = subsequent_mask(len(x)).to(x.device)

        x = self.self_attn_norm(x, self.self_attn(query=x, key=x, value=x, mask=self.mask))
        x = self.feed_forward_norm(x, self.feed_forward(x))

        return x
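A quick usage sketch (not part of the original file): build a single layer with the same hyper-parameters as the experiment above and run a dummy batch through it. The [seq_len, batch_size, d_model] input shape follows how the autoregressive experiment feeds the model.

import torch
from labml_nn.normalization.deep_norm import DeepNormTransformerLayer
from labml_nn.transformers import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward

n_layers, d_model, n_heads = 64, 64, 4
alpha = (2. * n_layers) ** (1. / 4.)    # ~ 3.36
beta = (8. * n_layers) ** -(1. / 4.)    # ~ 0.21

layer = DeepNormTransformerLayer(d_model=d_model,
                                 self_attn=MultiHeadAttention(n_heads, d_model, dropout_prob=0.0),
                                 feed_forward=FeedForward(d_model=d_model, d_ff=d_model * 4),
                                 deep_norm_alpha=alpha,
                                 deep_norm_beta=beta)

x = torch.randn(256, 16, d_model)       # [seq_len, batch_size, d_model]
y = layer(x)
print(y.shape)                          # torch.Size([256, 16, 64])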