diff --git a/Makefile b/Makefile
index 781da712..8848a160 100644
--- a/Makefile
+++ b/Makefile
@@ -38,8 +38,12 @@ docs-zh: ## Chinese Translation
cd labml_nn; pylit --translate zh --translate_cache ../translate_cache --remove_empty_sections --title_md -t ../../../pylit/templates/nn -d ../docs/zh -w *
docs: ## Render annotated HTML
+ mv docs/zh docs_zh
+ mv docs/si docs_si
find ./docs/ -name "*.html" -type f -delete
find ./docs/ -name "*.svg" -type f -delete
+ mv docs_si docs/si
+ mv docs_zh docs/zh
python utils/sitemap.py
python utils/diagrams.py
cd labml_nn; pylit --remove_empty_sections --title_md -t ../../../pylit/templates/nn -d ../docs -w *
diff --git a/docs/activations/fta/experiment.html b/docs/activations/fta/experiment.html
index 398fd7c7..85316cb4 100644
--- a/docs/activations/fta/experiment.html
+++ b/docs/activations/fta/experiment.html
@@ -1,5 +1,5 @@
-
+
@@ -76,24 +76,24 @@
#
-
+
Here we train a transformer that uses Fuzzy Tiling Activation in the Feed-Forward Network. We use it for a language model and train it on the Tiny Shakespeare dataset for demonstration.
However, this is probably not the ideal task for FTA, and we believe FTA is more suitable for modeling data with continuous variables.
-
22 import copy
-23
-24 import torch
-25 import torch.nn as nn
-26
-27 from labml import experiment
-28 from labml.configs import option
-29 from labml_helpers.module import Module
-30 from labml_nn.activations.fta import FTA
-31 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
-32 from labml_nn.transformers import MultiHeadAttention , TransformerLayer
-33 from labml_nn.transformers.utils import subsequent_mask
+
21 import copy
+22
+23 import torch
+24 import torch.nn as nn
+25
+26 from labml import experiment
+27 from labml.configs import option
+28 from labml_helpers.module import Module
+29 from labml_nn.activations.fta import FTA
+30 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+31 from labml_nn.transformers import MultiHeadAttention , TransformerLayer
+32 from labml_nn.transformers.utils import subsequent_mask
@@ -105,7 +105,7 @@
-
36 class FeedForwardFTA ( nn . Module ):
+
35 class FeedForwardFTA ( nn . Module ):
@@ -124,9 +124,9 @@
-
41 def __init__ ( self , d_model : int , d_ff : int ,
-42 activation : FTA ,
-43 dropout : float = 0.1 ):
+
40 def __init__ ( self , d_model : int , d_ff : int ,
+41 activation : FTA ,
+42 dropout : float = 0.1 ):
@@ -137,7 +137,7 @@
@@ -149,7 +149,7 @@
-
52 self . layer1 = nn . Linear ( d_model , d_ff )
+
51 self . layer1 = nn . Linear ( d_model , d_ff )
@@ -161,7 +161,7 @@
-
54 self . layer2 = nn . Linear ( d_ff * activation . expansion_factor , d_model )
+
53 self . layer2 = nn . Linear ( d_ff * activation . expansion_factor , d_model )
@@ -173,7 +173,7 @@
-
56 self . dropout = nn . Dropout ( dropout )
+
55 self . dropout = nn . Dropout ( dropout )
@@ -185,7 +185,7 @@
-
58 self . activation = activation
+
57 self . activation = activation
@@ -196,7 +196,7 @@
-
60 def forward ( self , x : torch . Tensor ):
+
59 def forward ( self , x : torch . Tensor ):
@@ -208,7 +208,7 @@
-
62 x = self . activation ( self . layer1 ( x ))
+
61 x = self . activation ( self . layer1 ( x ))
@@ -220,7 +220,7 @@
@@ -232,7 +232,7 @@
@@ -245,7 +245,7 @@
-
69 class AutoregressiveTransformer ( Module ):
+
68 class AutoregressiveTransformer ( Module ):
@@ -265,7 +265,7 @@
-
77 def __init__ ( self , n_tokens : int , d_model : int , n_layers : int , layer : TransformerLayer ):
+
76 def __init__ ( self , n_tokens : int , d_model : int , n_layers : int , layer : TransformerLayer ):
@@ -276,7 +276,7 @@
@@ -289,7 +289,7 @@
-
86 self . transformer_layers = nn . ModuleList ([ copy . deepcopy ( layer ) for _ in range ( n_layers )])
+
85 self . transformer_layers = nn . ModuleList ([ copy . deepcopy ( layer ) for _ in range ( n_layers )])
@@ -301,7 +301,7 @@
-
89 self . emb = nn . Embedding ( n_tokens , d_model )
+
88 self . emb = nn . Embedding ( n_tokens , d_model )
@@ -313,7 +313,7 @@
-
91 self . readout = nn . Linear ( d_model , n_tokens )
+
90 self . readout = nn . Linear ( d_model , n_tokens )
@@ -325,7 +325,7 @@
@@ -339,7 +339,7 @@
-
96 def forward ( self , x : torch . Tensor ):
+
95 def forward ( self , x : torch . Tensor ):
@@ -351,7 +351,7 @@
-
101 if self . mask is None or self . mask . size ( 0 ) != len ( x ):
+
100 if self . mask is None or self . mask . size ( 0 ) != len ( x ):
@@ -363,7 +363,7 @@
-
103 self . mask = subsequent_mask ( len ( x )) . to ( x . device )
+
102 self . mask = subsequent_mask ( len ( x )) . to ( x . device )
@@ -375,7 +375,7 @@
@@ -387,8 +387,8 @@
-
108 for layer in self . transformer_layers :
-109 x = layer ( x = x , mask = self . mask )
+
107 for layer in self . transformer_layers :
+108 x = layer ( x = x , mask = self . mask )
@@ -400,7 +400,7 @@
@@ -412,7 +412,7 @@
@@ -426,7 +426,7 @@
-
117 class Configs ( NLPAutoRegressionConfigs ):
+
116 class Configs ( NLPAutoRegressionConfigs ):
@@ -438,7 +438,7 @@
-
126 model : AutoregressiveTransformer
+
125 model : AutoregressiveTransformer
@@ -450,7 +450,7 @@
@@ -462,8 +462,8 @@
-
132 deep_norm_alpha : float
-133 deep_norm_beta : float
+
131 deep_norm_alpha : float
+132 deep_norm_beta : float
@@ -475,7 +475,7 @@
@@ -487,7 +487,7 @@
@@ -499,7 +499,7 @@
@@ -511,7 +511,7 @@
@@ -523,10 +523,10 @@
-
145 fta_lower_limit : float = - 1.
-146 fta_upper_limit : float = + 1.
-147 fta_delta : float = 0.2
-148 fta_eta : float = 0.05
+
144 fta_lower_limit : float = - 1.
+145 fta_upper_limit : float = + 1.
+146 fta_delta : float = 0.2
+147 fta_eta : float = 0.05
@@ -538,8 +538,8 @@
-
151 @option ( Configs . model )
-152 def _model ( c : Configs ):
+
150 @option ( Configs . model )
+151 def _model ( c : Configs ):
@@ -551,7 +551,7 @@
-
158 fta = FTA ( c . fta_lower_limit , c . fta_upper_limit , c . fta_delta , c . fta_eta )
+
157 fta = FTA ( c . fta_lower_limit , c . fta_upper_limit , c . fta_delta , c . fta_eta )
@@ -565,15 +565,15 @@
-
162 m = AutoregressiveTransformer ( c . n_tokens , c . d_model , c . n_layers ,
-163 TransformerLayer ( d_model = c . d_model ,
-164 feed_forward = FeedForwardFTA ( d_model = c . d_model ,
-165 d_ff = c . d_ff ,
-166 activation = fta ,
-167 dropout = 0.1 ),
-168 self_attn = MultiHeadAttention ( c . n_heads , c . d_model ,
-169 dropout_prob = 0.0 ),
-170 dropout_prob = 0.0 ))
+
161 m = AutoregressiveTransformer ( c . n_tokens , c . d_model , c . n_layers ,
+162 TransformerLayer ( d_model = c . d_model ,
+163 feed_forward = FeedForwardFTA ( d_model = c . d_model ,
+164 d_ff = c . d_ff ,
+165 activation = fta ,
+166 dropout = 0.1 ),
+167 self_attn = MultiHeadAttention ( c . n_heads , c . d_model ,
+168 dropout_prob = 0.0 ),
+169 dropout_prob = 0.0 ))
@@ -585,7 +585,7 @@
-
173 return m . to ( c . device )
+
172 return m . to ( c . device )
@@ -597,7 +597,7 @@
@@ -609,7 +609,7 @@
-
181 experiment . create ( name = "fta" , writers = { 'screen' , 'comet' , 'labml' })
+
180 experiment . create ( name = "fta" , writers = { 'screen' , 'labml' })
@@ -621,7 +621,7 @@
@@ -633,7 +633,7 @@
-
185 experiment . configs ( conf , {
+
184 experiment . configs ( conf , {
@@ -645,7 +645,7 @@
-
187 'tokenizer' : 'character' ,
+
186 'tokenizer' : 'character' ,
@@ -657,7 +657,7 @@
-
189 'prompt_separator' : '' ,
+
188 'prompt_separator' : '' ,
@@ -669,7 +669,7 @@
@@ -681,7 +681,7 @@
-
193 'text' : 'tiny_shakespeare' ,
+
192 'text' : 'tiny_shakespeare' ,
@@ -693,7 +693,7 @@
@@ -705,7 +705,7 @@
@@ -717,7 +717,7 @@
@@ -729,7 +729,7 @@
-
202 'inner_iterations' : 10 ,
+
201 'inner_iterations' : 10 ,
@@ -741,9 +741,9 @@
-
205 'optimizer.optimizer' : 'Adam' ,
-206 'optimizer.learning_rate' : 3e-4 ,
-207 })
+
204 'optimizer.optimizer' : 'Adam' ,
+205 'optimizer.learning_rate' : 3e-4 ,
+206 })
@@ -755,7 +755,7 @@
-
210 experiment . add_pytorch_models ({ 'model' : conf . model })
+
209 experiment . add_pytorch_models ({ 'model' : conf . model })
@@ -767,7 +767,7 @@
-
213 with experiment . start ():
+
212 with experiment . start ():
@@ -779,7 +779,7 @@
@@ -791,8 +791,8 @@
-
219 if __name__ == '__main__' :
-220 main ()
+
218 if __name__ == '__main__' :
+219 main ()
Fuzzy Tiling Activations (FTA)
-
+
This is a PyTorch implementation/tutorial of Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online .
Fuzzy tiling activations are a form of sparse activations based on binning.
Binning is classification of a scalar value into a bin based on intervals. One problem with binning is that it gives zero gradients for most values (except at the boundary of bins). The other is that binning loses precision if the bin intervals are large.
@@ -99,8 +99,8 @@
-
62 import torch
-63 from torch import nn
+
61 import torch
+62 from torch import nn
@@ -112,7 +112,7 @@
@@ -131,7 +131,7 @@
-
71 def __init__ ( self , lower_limit : float , upper_limit : float , delta : float , eta : float ):
+
70 def __init__ ( self , lower_limit : float , upper_limit : float , delta : float , eta : float ):
@@ -142,7 +142,7 @@
@@ -154,7 +154,7 @@
-
81 self . c = nn . Parameter ( torch . arange ( lower_limit , upper_limit , delta ), requires_grad = False )
+
80 self . c = nn . Parameter ( torch . arange ( lower_limit , upper_limit , delta ), requires_grad = False )
@@ -166,7 +166,7 @@
-
83 self . expansion_factor = len ( self . c )
+
82 self . expansion_factor = len ( self . c )
@@ -178,7 +178,7 @@
@@ -190,7 +190,7 @@
@@ -203,7 +203,7 @@
-
89 def fuzzy_i_plus ( self , x : torch . Tensor ):
+
88 def fuzzy_i_plus ( self , x : torch . Tensor ):
@@ -214,7 +214,7 @@
-
95 return ( x <= self . eta ) * x + ( x > self . eta )
+
94 return ( x <= self . eta ) * x + ( x > self . eta )
@@ -225,7 +225,7 @@
-
97 def forward ( self , z : torch . Tensor ):
+
96 def forward ( self , z : torch . Tensor ):
@@ -237,7 +237,7 @@
-
100 z = z . view ( * z . shape , 1 )
+
99 z = z . view ( * z . shape , 1 )
@@ -249,7 +249,7 @@
-
103 z = 1. - self . fuzzy_i_plus ( torch . clip ( self . c - z , min = 0. ) + torch . clip ( z - self . delta - self . c , min = 0. ))
+
102 z = 1. - self . fuzzy_i_plus ( torch . clip ( self . c - z , min = 0. ) + torch . clip ( z - self . delta - self . c , min = 0. ))
@@ -261,7 +261,7 @@
-
107 return z . view ( * z . shape [: - 2 ], - 1 )
+
106 return z . view ( * z . shape [: - 2 ], - 1 )
@@ -273,7 +273,7 @@
@@ -284,7 +284,7 @@
-
114 from labml.logger import inspect
+
113 from labml.logger import inspect
@@ -296,7 +296,7 @@
-
117 a = FTA ( - 10 , 10 , 2. , 0.5 )
+
116 a = FTA ( - 10 , 10 , 2. , 0.5 )
@@ -308,7 +308,7 @@
@@ -320,7 +320,7 @@
-
121 inspect ( a . expansion_factor )
+
120 inspect ( a . expansion_factor )
@@ -332,7 +332,7 @@
-
124 z = torch . tensor ([ 1.1 , 2.2 , 3.3 , 4.4 , 5.5 , 6.6 , 7.7 , 8.8 , 9. , 10. , 11. ])
+
123 z = torch . tensor ([ 1.1 , 2.2 , 3.3 , 4.4 , 5.5 , 6.6 , 7.7 , 8.8 , 9. , 10. , 11. ])
@@ -344,7 +344,7 @@
@@ -356,11 +356,11 @@
-
128 inspect ( a ( z ))
+ 127 inspect ( a ( z ))
+128
129
-130
-131 if __name__ == '__main__' :
-132 _test ()
+130 if __name__ == '__main__' :
+131 _test ()
-
+
This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/celebA folder.
The paper used an exponential moving average of the model with a decay of 0.9999. We have skipped this for simplicity.
-
21 from typing import List
-22
-23 import torch
-24 import torch.utils.data
-25 import torchvision
-26 from PIL import Image
-27
-28 from labml import lab , tracker , experiment , monit
-29 from labml.configs import BaseConfigs , option
-30 from labml_helpers.device import DeviceConfigs
-31 from labml_nn.diffusion.ddpm import DenoiseDiffusion
-32 from labml_nn.diffusion.ddpm.unet import UNet
+
20 from typing import List
+21
+22 import torch
+23 import torch.utils.data
+24 import torchvision
+25 from PIL import Image
+26
+27 from labml import lab , tracker , experiment , monit
+28 from labml.configs import BaseConfigs , option
+29 from labml_helpers.device import DeviceConfigs
+30 from labml_nn.diffusion.ddpm import DenoiseDiffusion
+31 from labml_nn.diffusion.ddpm.unet import UNet
@@ -106,7 +106,7 @@
-
35 class Configs ( BaseConfigs ):
+
34 class Configs ( BaseConfigs ):
@@ -119,7 +119,7 @@
-
42 device : torch . device = DeviceConfigs ()
+
41 device : torch . device = DeviceConfigs ()
@@ -131,7 +131,7 @@
@@ -143,7 +143,7 @@
-
47 diffusion : DenoiseDiffusion
+
46 diffusion : DenoiseDiffusion
@@ -155,7 +155,7 @@
-
50 image_channels : int = 3
+
49 image_channels : int = 3
@@ -167,7 +167,7 @@
@@ -179,7 +179,7 @@
@@ -192,7 +192,7 @@
-
57 channel_multipliers : List [ int ] = [ 1 , 2 , 2 , 4 ]
+
56 channel_multipliers : List [ int ] = [ 1 , 2 , 2 , 4 ]
@@ -204,7 +204,7 @@
-
59 is_attention : List [ int ] = [ False , False , False , True ]
+
58 is_attention : List [ int ] = [ False , False , False , True ]
@@ -216,7 +216,7 @@
@@ -228,7 +228,7 @@
@@ -240,7 +240,7 @@
@@ -252,7 +252,7 @@
-
68 learning_rate : float = 2e-5
+
67 learning_rate : float = 2e-5
@@ -264,7 +264,7 @@
@@ -276,7 +276,7 @@
-
74 dataset : torch . utils . data . Dataset
+
73 dataset : torch . utils . data . Dataset
@@ -288,7 +288,7 @@
-
76 data_loader : torch . utils . data . DataLoader
+
75 data_loader : torch . utils . data . DataLoader
@@ -300,7 +300,7 @@
-
79 optimizer : torch . optim . Adam
+
78 optimizer : torch . optim . Adam
@@ -311,7 +311,7 @@
@@ -323,12 +323,12 @@
-
83 self . eps_model = UNet (
-84 image_channels = self . image_channels ,
-85 n_channels = self . n_channels ,
-86 ch_mults = self . channel_multipliers ,
-87 is_attn = self . is_attention ,
-88 ) . to ( self . device )
+
82 self . eps_model = UNet (
+83 image_channels = self . image_channels ,
+84 n_channels = self . n_channels ,
+85 ch_mults = self . channel_multipliers ,
+86 is_attn = self . is_attention ,
+87 ) . to ( self . device )
@@ -340,11 +340,11 @@
-
91 self . diffusion = DenoiseDiffusion (
-92 eps_model = self . eps_model ,
-93 n_steps = self . n_steps ,
-94 device = self . device ,
-95 )
+
90 self . diffusion = DenoiseDiffusion (
+91 eps_model = self . eps_model ,
+92 n_steps = self . n_steps ,
+93 device = self . device ,
+94 )
@@ -356,7 +356,7 @@
-
98 self . data_loader = torch . utils . data . DataLoader ( self . dataset , self . batch_size , shuffle = True , pin_memory = True )
+
97 self . data_loader = torch . utils . data . DataLoader ( self . dataset , self . batch_size , shuffle = True , pin_memory = True )
@@ -368,7 +368,7 @@
-
100 self . optimizer = torch . optim . Adam ( self . eps_model . parameters (), lr = self . learning_rate )
+
99 self . optimizer = torch . optim . Adam ( self . eps_model . parameters (), lr = self . learning_rate )
@@ -380,7 +380,7 @@
-
103 tracker . set_image ( "sample" , True )
+
102 tracker . set_image ( "sample" , True )
@@ -392,7 +392,7 @@
@@ -403,7 +403,7 @@
-
109 with torch . no_grad ():
+
108 with torch . no_grad ():
@@ -415,8 +415,8 @@
-
111 x = torch . randn ([ self . n_samples , self . image_channels , self . image_size , self . image_size ],
-112 device = self . device )
+
110 x = torch . randn ([ self . n_samples , self . image_channels , self . image_size , self . image_size ],
+111 device = self . device )
@@ -428,7 +428,7 @@
-
115 for t_ in monit . iterate ( 'Sample' , self . n_steps ):
+
114 for t_ in monit . iterate ( 'Sample' , self . n_steps ):
@@ -440,7 +440,7 @@
-
117 t = self . n_steps - t_ - 1
+
116 t = self . n_steps - t_ - 1
@@ -452,7 +452,7 @@
-
119 x = self . diffusion . p_sample ( x , x . new_full (( self . n_samples ,), t , dtype = torch . long ))
+
118 x = self . diffusion . p_sample ( x , x . new_full (( self . n_samples ,), t , dtype = torch . long ))
@@ -464,7 +464,7 @@
-
122 tracker . save ( 'sample' , x )
+
121 tracker . save ( 'sample' , x )
@@ -476,7 +476,7 @@
@@ -488,7 +488,7 @@
-
130 for data in monit . iterate ( 'Train' , self . data_loader ):
+
129 for data in monit . iterate ( 'Train' , self . data_loader ):
@@ -500,7 +500,7 @@
-
132 tracker . add_global_step ()
+
131 tracker . add_global_step ()
@@ -512,7 +512,7 @@
-
134 data = data . to ( self . device )
+
133 data = data . to ( self . device )
@@ -524,7 +524,7 @@
-
137 self . optimizer . zero_grad ()
+
136 self . optimizer . zero_grad ()
@@ -536,7 +536,7 @@
-
139 loss = self . diffusion . loss ( data )
+
138 loss = self . diffusion . loss ( data )
@@ -548,7 +548,7 @@
@@ -560,7 +560,7 @@
-
143 self . optimizer . step ()
+
142 self . optimizer . step ()
@@ -572,7 +572,7 @@
-
145 tracker . save ( 'loss' , loss )
+
144 tracker . save ( 'loss' , loss )
@@ -584,7 +584,7 @@
@@ -595,7 +595,7 @@
-
151 for _ in monit . loop ( self . epochs ):
+
150 for _ in monit . loop ( self . epochs ):
@@ -607,7 +607,7 @@
@@ -619,7 +619,7 @@
@@ -631,7 +631,7 @@
@@ -643,7 +643,7 @@
-
159 experiment . save_checkpoint ()
+
158 experiment . save_checkpoint ()
@@ -655,7 +655,7 @@
-
162 class CelebADataset ( torch . utils . data . Dataset ):
+
161 class CelebADataset ( torch . utils . data . Dataset ):
@@ -666,8 +666,8 @@
-
167 def __init__ ( self , image_size : int ):
-168 super () . __init__ ()
+
166 def __init__ ( self , image_size : int ):
+167 super () . __init__ ()
@@ -679,7 +679,7 @@
-
171 folder = lab . get_data_path () / 'celebA'
+
170 folder = lab . get_data_path () / 'celebA'
@@ -691,7 +691,7 @@
-
173 self . _files = [ p for p in folder . glob ( f '**/*.jpg' )]
+
172 self . _files = [ p for p in folder . glob ( f '**/*.jpg' )]
@@ -703,10 +703,10 @@
-
176 self . _transform = torchvision . transforms . Compose ([
-177 torchvision . transforms . Resize ( image_size ),
-178 torchvision . transforms . ToTensor (),
-179 ])
+
175 self . _transform = torchvision . transforms . Compose ([
+176 torchvision . transforms . Resize ( image_size ),
+177 torchvision . transforms . ToTensor (),
+178 ])
@@ -718,7 +718,7 @@
@@ -729,7 +729,7 @@
-
185 return len ( self . _files )
+
184 return len ( self . _files )
@@ -741,7 +741,7 @@
-
187 def __getitem__ ( self , index : int ):
+
186 def __getitem__ ( self , index : int ):
@@ -752,8 +752,8 @@
-
191 img = Image . open ( self . _files [ index ])
-192 return self . _transform ( img )
+
190 img = Image . open ( self . _files [ index ])
+191 return self . _transform ( img )
@@ -765,8 +765,8 @@
-
195 @option ( Configs . dataset , 'CelebA' )
-196 def celeb_dataset ( c : Configs ):
+
194 @option ( Configs . dataset , 'CelebA' )
+195 def celeb_dataset ( c : Configs ):
@@ -777,7 +777,7 @@
-
200 return CelebADataset ( c . image_size )
+
199 return CelebADataset ( c . image_size )
@@ -789,7 +789,7 @@
-
203 class MNISTDataset ( torchvision . datasets . MNIST ):
+
202 class MNISTDataset ( torchvision . datasets . MNIST ):
@@ -800,13 +800,13 @@
-
208 def __init__ ( self , image_size ):
-209 transform = torchvision . transforms . Compose ([
-210 torchvision . transforms . Resize ( image_size ),
-211 torchvision . transforms . ToTensor (),
-212 ])
-213
-214 super () . __init__ ( str ( lab . get_data_path ()), train = True , download = True , transform = transform )
+
207 def __init__ ( self , image_size ):
+208 transform = torchvision . transforms . Compose ([
+209 torchvision . transforms . Resize ( image_size ),
+210 torchvision . transforms . ToTensor (),
+211 ])
+212
+213 super () . __init__ ( str ( lab . get_data_path ()), train = True , download = True , transform = transform )
@@ -817,8 +817,8 @@
-
216 def __getitem__ ( self , item ):
-217 return super () . __getitem__ ( item )[ 0 ]
+
215 def __getitem__ ( self , item ):
+216 return super () . __getitem__ ( item )[ 0 ]
@@ -830,8 +830,8 @@
-
220 @option ( Configs . dataset , 'MNIST' )
-221 def mnist_dataset ( c : Configs ):
+
219 @option ( Configs . dataset , 'MNIST' )
+220 def mnist_dataset ( c : Configs ):
@@ -842,7 +842,7 @@
-
225 return MNISTDataset ( c . image_size )
+
224 return MNISTDataset ( c . image_size )
@@ -853,7 +853,7 @@
@@ -865,7 +865,7 @@
-
230 experiment . create ( name = 'diffuse' , writers = { 'screen' , 'comet' })
+
229 experiment . create ( name = 'diffuse' , writers = { 'screen' , 'labml' })
@@ -877,7 +877,7 @@
@@ -889,11 +889,11 @@
-
236 experiment . configs ( configs , {
-237 'dataset' : 'CelebA' , # 'MNIST'
-238 'image_channels' : 3 , # 1,
-239 'epochs' : 100 , # 5,
-240 })
+
235 experiment . configs ( configs , {
+236 'dataset' : 'CelebA' , # 'MNIST'
+237 'image_channels' : 3 , # 1,
+238 'epochs' : 100 , # 5,
+239 })
@@ -905,7 +905,7 @@
@@ -917,7 +917,7 @@
-
246 experiment . add_pytorch_models ({ 'eps_model' : configs . eps_model })
+
245 experiment . add_pytorch_models ({ 'eps_model' : configs . eps_model })
@@ -929,8 +929,8 @@
-
249 with experiment . start ():
-250 configs . run ()
+
248 with experiment . start ():
+249 configs . run ()
@@ -942,8 +942,8 @@
-
254 if __name__ == '__main__' :
-255 main ()
+
253 if __name__ == '__main__' :
+254 main ()
Denoising Diffusion Probabilistic Models (DDPM)
-
+
This is a PyTorch implementation/tutorial of the paper Denoising Diffusion Probabilistic Models .
In simple terms, we get an image from data and add noise step by step. Then we train a model to predict that noise at each step and use the model to generate images.
The following definitions and derivations show how this works. For details please refer to the paper .
@@ -278,7 +278,7 @@ s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z">​ ϵ , t ) ∥ ∥ ​ 2 ] ​ That is, we are training to predict the noise.
Simplified loss
-L s ​ im pl e ( θ ) = E t , x 0 ​ , ϵ ​ [ ∥ ∥ ​ ϵ − ϵ θ ​ ( α t ​ ˉ ​ L simple ​ ( θ ) = E t , x 0 ​ , ϵ ​ [ ∥ ∥ ​ ϵ − ϵ θ ​ ( α t ​ ˉ ​ 163 from typing import Tuple , Optional
-164
-165 import torch
-166 import torch.nn.functional as F
-167 import torch.utils.data
-168 from torch import nn
-169
-170 from labml_nn.diffusion.ddpm.utils import gather
+ 162 from typing import Tuple , Optional
+163
+164 import torch
+165 import torch.nn.functional as F
+166 import torch.utils.data
+167 from torch import nn
+168
+169 from labml_nn.diffusion.ddpm.utils import gather
@@ -326,7 +326,7 @@ M834 80h400000v40h-400000z">
173 class DenoiseDiffusion :
+ 172 class DenoiseDiffusion :
@@ -343,7 +343,7 @@ M834 80h400000v40h-400000z">
178 def __init__ ( self , eps_model : nn . Module , n_steps : int , device : torch . device ):
+ 177 def __init__ ( self , eps_model : nn . Module , n_steps : int , device : torch . device ):
@@ -354,8 +354,8 @@ M834 80h400000v40h-400000z">
184 super () . __init__ ()
-185 self . eps_model = eps_model
+ 183 super () . __init__ ()
+184 self . eps_model = eps_model
@@ -367,7 +367,7 @@ M834 80h400000v40h-400000z">
188 self . beta = torch . linspace ( 0.0001 , 0.02 , n_steps ) . to ( device )
+ 187 self . beta = torch . linspace ( 0.0001 , 0.02 , n_steps ) . to ( device )
@@ -379,7 +379,7 @@ M834 80h400000v40h-400000z">
191 self . alpha = 1. - self . beta
+ 190 self . alpha = 1. - self . beta
@@ -391,7 +391,7 @@ M834 80h400000v40h-400000z">
193 self . alpha_bar = torch . cumprod ( self . alpha , dim = 0 )
+ 192 self . alpha_bar = torch . cumprod ( self . alpha , dim = 0 )
@@ -403,7 +403,7 @@ M834 80h400000v40h-400000z">
195 self . n_steps = n_steps
+ 194 self . n_steps = n_steps
@@ -415,7 +415,7 @@ M834 80h400000v40h-400000z">
197 self . sigma2 = self . beta
+ 196 self . sigma2 = self . beta
@@ -438,7 +438,7 @@ c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z">​ x 0 ​ , ( 1 − α t ​ ˉ ​ ) I ) ​
-
199 def q_xt_x0 ( self , x0 : torch . Tensor , t : torch . Tensor ) -> Tuple [ torch . Tensor , torch . Tensor ]:
+
198 def q_xt_x0 ( self , x0 : torch . Tensor , t : torch . Tensor ) -> Tuple [ torch . Tensor , torch . Tensor ]:
@@ -461,7 +461,7 @@ M834 80h400000v40h-400000z">
209 mean = gather ( self . alpha_bar , t ) ** 0.5 * x0
+ 208 mean = gather ( self . alpha_bar , t ) ** 0.5 * x0
@@ -473,7 +473,7 @@ M834 80h400000v40h-400000z">
211 var = 1 - gather ( self . alpha_bar , t )
+ 210 var = 1 - gather ( self . alpha_bar , t )
@@ -485,7 +485,7 @@ M834 80h400000v40h-400000z">
213 return mean , var
+
@@ -508,7 +508,7 @@ c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z">​ x 0 ​ , ( 1 − α t ​ ˉ ​ ) I ) ​
-
215 def q_sample ( self , x0 : torch . Tensor , t : torch . Tensor , eps : Optional [ torch . Tensor ] = None ):
+
214 def q_sample ( self , x0 : torch . Tensor , t : torch . Tensor , eps : Optional [ torch . Tensor ] = None ):
@@ -520,8 +520,8 @@ M834 80h400000v40h-400000z">
225 if eps is None :
-226 eps = torch . randn_like ( x0 )
+ 224 if eps is None :
+225 eps = torch . randn_like ( x0 )
@@ -533,7 +533,7 @@ M834 80h400000v40h-400000z">
229 mean , var = self . q_xt_x0 ( x0 , t )
+ 228 mean , var = self . q_xt_x0 ( x0 , t )
@@ -545,7 +545,7 @@ M834 80h400000v40h-400000z">
231 return mean + ( var ** 0.5 ) * eps
+ 230 return mean + ( var ** 0.5 ) * eps
@@ -579,7 +579,7 @@ c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z">​ β t ​ ​ ϵ θ ​ ( x t ​ , t ) ) ​
-
233 def p_sample ( self , xt : torch . Tensor , t : torch . Tensor ):
+
232 def p_sample ( self , xt : torch . Tensor , t : torch . Tensor ):
@@ -591,7 +591,7 @@ M834 80h400000v40h-400000z">
247 eps_theta = self . eps_model ( xt , t )
+ 246 eps_theta = self . eps_model ( xt , t )
@@ -603,7 +603,7 @@ M834 80h400000v40h-400000z">
249 alpha_bar = gather ( self . alpha_bar , t )
+ 248 alpha_bar = gather ( self . alpha_bar , t )
@@ -615,7 +615,7 @@ M834 80h400000v40h-400000z">
251 alpha = gather ( self . alpha , t )
+ 250 alpha = gather ( self . alpha , t )
@@ -638,7 +638,7 @@ M834 80h400000v40h-400000z">
253 eps_coef = ( 1 - alpha ) / ( 1 - alpha_bar ) ** .5
+ 252 eps_coef = ( 1 - alpha ) / ( 1 - alpha_bar ) ** .5
@@ -672,7 +672,7 @@ M834 80h400000v40h-400000z">
256 mean = 1 / ( alpha ** 0.5 ) * ( xt - eps_coef * eps_theta )
+ 255 mean = 1 / ( alpha ** 0.5 ) * ( xt - eps_coef * eps_theta )
@@ -684,7 +684,7 @@ M834 80h400000v40h-400000z">
258 var = gather ( self . sigma2 , t )
+ 257 var = gather ( self . sigma2 , t )
@@ -696,7 +696,7 @@ M834 80h400000v40h-400000z">
261 eps = torch . randn ( xt . shape , device = xt . device )
+ 260 eps = torch . randn ( xt . shape , device = xt . device )
@@ -708,7 +708,7 @@ M834 80h400000v40h-400000z">
263 return mean + ( var ** .5 ) * eps
+ 262 return mean + ( var ** .5 ) * eps
@@ -717,7 +717,7 @@ M834 80h400000v40h-400000z">
$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon} \Big[ \big\Vert \epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon,\; t\big) \big\Vert^2 \Big]$$ 265 def loss ( self , x0 : torch . Tensor , noise : Optional [ torch . Tensor ] = None ):
+ 264 def loss ( self , x0 : torch . Tensor , noise : Optional [ torch . Tensor ] = None ):
@@ -755,7 +755,7 @@ M834 80h400000v40h-400000z">
274 batch_size = x0 . shape [ 0 ]
+ 273 batch_size = x0 . shape [ 0 ]
@@ -767,7 +767,7 @@ M834 80h400000v40h-400000z">
276 t = torch . randint ( 0 , self . n_steps , ( batch_size ,), device = x0 . device , dtype = torch . long )
+ 275 t = torch . randint ( 0 , self . n_steps , ( batch_size ,), device = x0 . device , dtype = torch . long )
@@ -779,8 +779,8 @@ M834 80h400000v40h-400000z">
279 if noise is None :
-280 noise = torch . randn_like ( x0 )
+ 278 if noise is None :
+279 noise = torch . randn_like ( x0 )
@@ -792,7 +792,7 @@ M834 80h400000v40h-400000z">
283 xt = self . q_sample ( x0 , t , eps = noise )
+ 282 xt = self . q_sample ( x0 , t , eps = noise )
@@ -826,7 +826,7 @@ M834 80h400000v40h-400000z">
285 eps_theta = self . eps_model ( xt , t )
+ 284 eps_theta = self . eps_model ( xt , t )
@@ -838,7 +838,7 @@ M834 80h400000v40h-400000z">
288 return F . mse_loss ( noise , eps_theta )
+ 287 return F . mse_loss ( noise , eps_theta )
-
+
This is a PyTorch implementation/tutorial of the paper Denoising Diffusion Probabilistic Models .
In simple terms, we get an image from data and add noise step by step. Then we train a model to predict that noise at each step and use the model to generate images.
Here is the UNet model that predicts the noise and training code . This file can generate samples and interpolations from a trained model.
diff --git a/docs/diffusion/ddpm/unet.html b/docs/diffusion/ddpm/unet.html
index 4eabaf42..1c2109ce 100644
--- a/docs/diffusion/ddpm/unet.html
+++ b/docs/diffusion/ddpm/unet.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/diffusion/ddpm/utils.html b/docs/diffusion/ddpm/utils.html
index 95995e82..a1e25c23 100644
--- a/docs/diffusion/ddpm/utils.html
+++ b/docs/diffusion/ddpm/utils.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/diffusion/index.html b/docs/diffusion/index.html
index a9e9ba1c..cef3d8d4 100644
--- a/docs/diffusion/index.html
+++ b/docs/diffusion/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/distillation/index.html b/docs/distillation/index.html
index d6f4c942..506d6787 100644
--- a/docs/distillation/index.html
+++ b/docs/distillation/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/distillation/large.html b/docs/distillation/large.html
index 840dbe3f..bca72b2d 100644
--- a/docs/distillation/large.html
+++ b/docs/distillation/large.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/distillation/readme.html b/docs/distillation/readme.html
index 74c52600..f296823f 100644
--- a/docs/distillation/readme.html
+++ b/docs/distillation/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/distillation/small.html b/docs/distillation/small.html
index dbcafdd6..ab221183 100644
--- a/docs/distillation/small.html
+++ b/docs/distillation/small.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/experiments/arithmetic_dataset.html b/docs/experiments/arithmetic_dataset.html
index 2b6880f4..b4af9a75 100644
--- a/docs/experiments/arithmetic_dataset.html
+++ b/docs/experiments/arithmetic_dataset.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/experiments/cifar10.html b/docs/experiments/cifar10.html
index 23853103..ee48359e 100644
--- a/docs/experiments/cifar10.html
+++ b/docs/experiments/cifar10.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/experiments/index.html b/docs/experiments/index.html
index bf1253f4..29b75e7a 100644
--- a/docs/experiments/index.html
+++ b/docs/experiments/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/experiments/mnist.html b/docs/experiments/mnist.html
index 5a50610c..01c5c7f0 100644
--- a/docs/experiments/mnist.html
+++ b/docs/experiments/mnist.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/experiments/nlp_autoregression.html b/docs/experiments/nlp_autoregression.html
index f1ee53d3..c3d46a70 100644
--- a/docs/experiments/nlp_autoregression.html
+++ b/docs/experiments/nlp_autoregression.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/experiments/nlp_classification.html b/docs/experiments/nlp_classification.html
index 6fde095c..a57bb5ee 100644
--- a/docs/experiments/nlp_classification.html
+++ b/docs/experiments/nlp_classification.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/cycle_gan/index.html b/docs/gan/cycle_gan/index.html
index 48b7deab..9bf9c7bc 100644
--- a/docs/gan/cycle_gan/index.html
+++ b/docs/gan/cycle_gan/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/cycle_gan/readme.html b/docs/gan/cycle_gan/readme.html
index 2aee81a7..811521cc 100644
--- a/docs/gan/cycle_gan/readme.html
+++ b/docs/gan/cycle_gan/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/dcgan/index.html b/docs/gan/dcgan/index.html
index dc5492c3..48857c71 100644
--- a/docs/gan/dcgan/index.html
+++ b/docs/gan/dcgan/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/dcgan/readme.html b/docs/gan/dcgan/readme.html
index 394b9448..91207ece 100644
--- a/docs/gan/dcgan/readme.html
+++ b/docs/gan/dcgan/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/index.html b/docs/gan/index.html
index 14bfcfbe..4983e0c8 100644
--- a/docs/gan/index.html
+++ b/docs/gan/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/original/experiment.html b/docs/gan/original/experiment.html
index 5629f146..03b2a853 100644
--- a/docs/gan/original/experiment.html
+++ b/docs/gan/original/experiment.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/original/index.html b/docs/gan/original/index.html
index 0352128b..1fe45cb9 100644
--- a/docs/gan/original/index.html
+++ b/docs/gan/original/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/original/readme.html b/docs/gan/original/readme.html
index 6ca6cf47..2fb2c9b7 100644
--- a/docs/gan/original/readme.html
+++ b/docs/gan/original/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/stylegan/experiment.html b/docs/gan/stylegan/experiment.html
index a0dd4f62..16fcfb4e 100644
--- a/docs/gan/stylegan/experiment.html
+++ b/docs/gan/stylegan/experiment.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/stylegan/index.html b/docs/gan/stylegan/index.html
index eecbe12d..f44a19d1 100644
--- a/docs/gan/stylegan/index.html
+++ b/docs/gan/stylegan/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/stylegan/readme.html b/docs/gan/stylegan/readme.html
index 8a917138..37dd4d53 100644
--- a/docs/gan/stylegan/readme.html
+++ b/docs/gan/stylegan/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/wasserstein/experiment.html b/docs/gan/wasserstein/experiment.html
index b71c631f..482867e2 100644
--- a/docs/gan/wasserstein/experiment.html
+++ b/docs/gan/wasserstein/experiment.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/wasserstein/gradient_penalty/experiment.html b/docs/gan/wasserstein/gradient_penalty/experiment.html
index 9b544807..16e2889b 100644
--- a/docs/gan/wasserstein/gradient_penalty/experiment.html
+++ b/docs/gan/wasserstein/gradient_penalty/experiment.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/wasserstein/gradient_penalty/index.html b/docs/gan/wasserstein/gradient_penalty/index.html
index ecf739dc..432392e7 100644
--- a/docs/gan/wasserstein/gradient_penalty/index.html
+++ b/docs/gan/wasserstein/gradient_penalty/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/wasserstein/gradient_penalty/readme.html b/docs/gan/wasserstein/gradient_penalty/readme.html
index d1b446a4..0447fbc6 100644
--- a/docs/gan/wasserstein/gradient_penalty/readme.html
+++ b/docs/gan/wasserstein/gradient_penalty/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/wasserstein/index.html b/docs/gan/wasserstein/index.html
index 5735c7c5..2baf8219 100644
--- a/docs/gan/wasserstein/index.html
+++ b/docs/gan/wasserstein/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/gan/wasserstein/readme.html b/docs/gan/wasserstein/readme.html
index 88b99582..cbb6d5e2 100644
--- a/docs/gan/wasserstein/readme.html
+++ b/docs/gan/wasserstein/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/graphs/gat/experiment.html b/docs/graphs/gat/experiment.html
index 0eb6929e..eb5eb6c7 100644
--- a/docs/graphs/gat/experiment.html
+++ b/docs/graphs/gat/experiment.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/graphs/gat/index.html b/docs/graphs/gat/index.html
index 05600b81..2a66954c 100644
--- a/docs/graphs/gat/index.html
+++ b/docs/graphs/gat/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/graphs/gat/readme.html b/docs/graphs/gat/readme.html
index c7ac59e1..9d7662fd 100644
--- a/docs/graphs/gat/readme.html
+++ b/docs/graphs/gat/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/graphs/gatv2/experiment.html b/docs/graphs/gatv2/experiment.html
index 4877a782..b7e89d4a 100644
--- a/docs/graphs/gatv2/experiment.html
+++ b/docs/graphs/gatv2/experiment.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/graphs/gatv2/index.html b/docs/graphs/gatv2/index.html
index 8e596a02..06e7693e 100644
--- a/docs/graphs/gatv2/index.html
+++ b/docs/graphs/gatv2/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/graphs/gatv2/readme.html b/docs/graphs/gatv2/readme.html
index 46e307be..220179db 100644
--- a/docs/graphs/gatv2/readme.html
+++ b/docs/graphs/gatv2/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/graphs/index.html b/docs/graphs/index.html
index b1674934..05248b04 100644
--- a/docs/graphs/index.html
+++ b/docs/graphs/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/hypernetworks/experiment.html b/docs/hypernetworks/experiment.html
index 78d5273e..9b3c1f52 100644
--- a/docs/hypernetworks/experiment.html
+++ b/docs/hypernetworks/experiment.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/hypernetworks/hyper_lstm.html b/docs/hypernetworks/hyper_lstm.html
index 0a55f7ca..4cde7abb 100644
--- a/docs/hypernetworks/hyper_lstm.html
+++ b/docs/hypernetworks/hyper_lstm.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/hypernetworks/index.html b/docs/hypernetworks/index.html
index cfb06468..009263d1 100644
--- a/docs/hypernetworks/index.html
+++ b/docs/hypernetworks/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/index.html b/docs/index.html
index 677273bb..1f06b6d0 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/lstm/index.html b/docs/lstm/index.html
index 9dd24040..8225d9d6 100644
--- a/docs/lstm/index.html
+++ b/docs/lstm/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/checkpoint.html b/docs/neox/checkpoint.html
index ebfc64b4..d471dffa 100644
--- a/docs/neox/checkpoint.html
+++ b/docs/neox/checkpoint.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/evaluation/half_precision.html b/docs/neox/evaluation/half_precision.html
index 0df43082..0d9774da 100644
--- a/docs/neox/evaluation/half_precision.html
+++ b/docs/neox/evaluation/half_precision.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/evaluation/index.html b/docs/neox/evaluation/index.html
index 1cac38ce..2c61e919 100644
--- a/docs/neox/evaluation/index.html
+++ b/docs/neox/evaluation/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/evaluation/llm_int8.html b/docs/neox/evaluation/llm_int8.html
index b12f2a83..cff95374 100644
--- a/docs/neox/evaluation/llm_int8.html
+++ b/docs/neox/evaluation/llm_int8.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/index.html b/docs/neox/index.html
index 43f79328..f7922ddf 100644
--- a/docs/neox/index.html
+++ b/docs/neox/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/model.html b/docs/neox/model.html
index b0906adc..9092a600 100644
--- a/docs/neox/model.html
+++ b/docs/neox/model.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/readme.html b/docs/neox/readme.html
index 149419b7..91612c7f 100644
--- a/docs/neox/readme.html
+++ b/docs/neox/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/samples/finetune.html b/docs/neox/samples/finetune.html
index fb5d7435..81def7f7 100644
--- a/docs/neox/samples/finetune.html
+++ b/docs/neox/samples/finetune.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/samples/generate.html b/docs/neox/samples/generate.html
index 8228b9a2..5e84f046 100644
--- a/docs/neox/samples/generate.html
+++ b/docs/neox/samples/generate.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/samples/index.html b/docs/neox/samples/index.html
index cac8a13e..93168d6f 100644
--- a/docs/neox/samples/index.html
+++ b/docs/neox/samples/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/samples/llm_int8.html b/docs/neox/samples/llm_int8.html
index b3d024dd..85111093 100644
--- a/docs/neox/samples/llm_int8.html
+++ b/docs/neox/samples/llm_int8.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/tokenizer.html b/docs/neox/tokenizer.html
index 420b22a2..0308d2fc 100644
--- a/docs/neox/tokenizer.html
+++ b/docs/neox/tokenizer.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/utils/cache.html b/docs/neox/utils/cache.html
index aad9ee87..96711f8a 100644
--- a/docs/neox/utils/cache.html
+++ b/docs/neox/utils/cache.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/utils/finetune.html b/docs/neox/utils/finetune.html
index e11a8813..86370aac 100644
--- a/docs/neox/utils/finetune.html
+++ b/docs/neox/utils/finetune.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/utils/index.html b/docs/neox/utils/index.html
index 0806cc80..7e144c03 100644
--- a/docs/neox/utils/index.html
+++ b/docs/neox/utils/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/utils/llm_int8.html b/docs/neox/utils/llm_int8.html
index a4e12c6c..8588fdb4 100644
--- a/docs/neox/utils/llm_int8.html
+++ b/docs/neox/utils/llm_int8.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/utils/text_dataset.html b/docs/neox/utils/text_dataset.html
index c1b493a4..09479080 100644
--- a/docs/neox/utils/text_dataset.html
+++ b/docs/neox/utils/text_dataset.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/neox/utils/trainer.html b/docs/neox/utils/trainer.html
index 37a53cbc..d43bce74 100644
--- a/docs/neox/utils/trainer.html
+++ b/docs/neox/utils/trainer.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/normalization/batch_channel_norm/index.html b/docs/normalization/batch_channel_norm/index.html
index dbab9406..1f1799c4 100644
--- a/docs/normalization/batch_channel_norm/index.html
+++ b/docs/normalization/batch_channel_norm/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/normalization/batch_norm/cifar10.html b/docs/normalization/batch_norm/cifar10.html
index 211ef380..fbadc07e 100644
--- a/docs/normalization/batch_norm/cifar10.html
+++ b/docs/normalization/batch_norm/cifar10.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/normalization/batch_norm/index.html b/docs/normalization/batch_norm/index.html
index be959d8c..0d52615d 100644
--- a/docs/normalization/batch_norm/index.html
+++ b/docs/normalization/batch_norm/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/normalization/batch_norm/mnist.html b/docs/normalization/batch_norm/mnist.html
index 7cb1b069..3ea3eb98 100644
--- a/docs/normalization/batch_norm/mnist.html
+++ b/docs/normalization/batch_norm/mnist.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/normalization/batch_norm/readme.html b/docs/normalization/batch_norm/readme.html
index 13394ada..706e623e 100644
--- a/docs/normalization/batch_norm/readme.html
+++ b/docs/normalization/batch_norm/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/normalization/deep_norm/experiment.html b/docs/normalization/deep_norm/experiment.html
index f2fe3b91..e1ba372a 100644
--- a/docs/normalization/deep_norm/experiment.html
+++ b/docs/normalization/deep_norm/experiment.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/normalization/deep_norm/index.html b/docs/normalization/deep_norm/index.html
index 498c4919..242cedb3 100644
--- a/docs/normalization/deep_norm/index.html
+++ b/docs/normalization/deep_norm/index.html
@@ -1,5 +1,5 @@
-
+
@@ -76,7 +76,7 @@
#
DeepNorm
-
+
This is a PyTorch implementation of the DeepNorm from the paper DeepNet: Scaling Transformers to 1,000 Layers .
The paper proposes a method to stabilize extremely deep transformers through a new normalizing function to replace LayerNorm and a weight initialization scheme. This combines the performance of Post-LayerNorm and the stability of Pre-LayerNorm. Transformers with DeepNorms are supposed to be stable even without a learning rate warm-up.
The paper first shows that the changes to layer outputs (for the same input) change gradually during stable training; when unstable it changes rapidly during the initial training steps. This happens with initializing weights to small values, and learning rate warm-ups where the training is stable. They use the idea of keeping the changes to layer outputs small to derive the new normalization and weight initialization mechanism.
@@ -95,15 +95,15 @@
-
74 from typing import Union , List
-75
-76 import torch
-77 from torch import nn , Size
-78
-79 from labml_nn.normalization.layer_norm import LayerNorm
-80 from labml_nn.transformers import MultiHeadAttention
-81 from labml_nn.transformers.feed_forward import FeedForward
-82 from labml_nn.transformers.utils import subsequent_mask
+
73 from typing import Union , List
+74
+75 import torch
+76 from torch import nn , Size
+77
+78 from labml_nn.normalization.layer_norm import LayerNorm
+79 from labml_nn.transformers import MultiHeadAttention
+80 from labml_nn.transformers.feed_forward import FeedForward
+81 from labml_nn.transformers.utils import subsequent_mask
@@ -116,7 +116,7 @@
-
85 class DeepNorm ( nn . Module ):
+
84 class DeepNorm ( nn . Module ):
@@ -135,9 +135,9 @@
-
92 def __init__ ( self , alpha : float , normalized_shape : Union [ int , List [ int ], Size ], * ,
-93 eps : float = 1e-5 ,
-94 elementwise_affine : bool = True ):
+
91 def __init__ ( self , alpha : float , normalized_shape : Union [ int , List [ int ], Size ], * ,
+92 eps : float = 1e-5 ,
+93 elementwise_affine : bool = True ):
@@ -148,9 +148,9 @@
-
101 super () . __init__ ()
-102
-103 self . alpha = alpha
+
100 super () . __init__ ()
+101
+102 self . alpha = alpha
@@ -162,7 +162,7 @@
-
105 self . layer_norm = LayerNorm ( normalized_shape , eps = eps , elementwise_affine = elementwise_affine )
+
104 self . layer_norm = LayerNorm ( normalized_shape , eps = eps , elementwise_affine = elementwise_affine )
@@ -177,7 +177,7 @@
-
107 def forward ( self , x : torch . Tensor , gx : torch . Tensor ):
+
106 def forward ( self , x : torch . Tensor , gx : torch . Tensor ):
@@ -189,7 +189,7 @@
-
113 return self . layer_norm ( x + self . alpha * gx )
+
112 return self . layer_norm ( x + self . alpha * gx )
@@ -202,7 +202,7 @@
-
116 class DeepNormTransformerLayer ( nn . Module ):
+
115 class DeepNormTransformerLayer ( nn . Module ):
@@ -223,13 +223,13 @@
-
123 def __init__ ( self , * ,
-124 d_model : int ,
-125 self_attn : MultiHeadAttention ,
-126 feed_forward : FeedForward ,
-127 deep_norm_alpha : float ,
-128 deep_norm_beta : float ,
-129 ):
+
122 def __init__ ( self , * ,
+123 d_model : int ,
+124 self_attn : MultiHeadAttention ,
+125 feed_forward : FeedForward ,
+126 deep_norm_alpha : float ,
+127 deep_norm_beta : float ,
+128 ):
@@ -240,10 +240,10 @@
-
137 super () . __init__ ()
-138
-139 self . self_attn = self_attn
-140 self . feed_forward = feed_forward
+
136 super () . __init__ ()
+137
+138 self . self_attn = self_attn
+139 self . feed_forward = feed_forward
@@ -255,8 +255,8 @@
-
142 self . self_attn_norm = DeepNorm ( deep_norm_alpha , [ d_model ])
-143 self . feed_forward_norm = DeepNorm ( deep_norm_alpha , [ d_model ])
+
141 self . self_attn_norm = DeepNorm ( deep_norm_alpha , [ d_model ])
+142 self . feed_forward_norm = DeepNorm ( deep_norm_alpha , [ d_model ])
@@ -268,7 +268,7 @@
-
146 with torch . no_grad ():
+
145 with torch . no_grad ():
@@ -280,8 +280,8 @@
-
148 feed_forward . layer1 . weight *= deep_norm_beta
-149 feed_forward . layer2 . weight *= deep_norm_beta
+
147 feed_forward . layer1 . weight *= deep_norm_beta
+148 feed_forward . layer2 . weight *= deep_norm_beta
@@ -293,7 +293,7 @@
-
152 self_attn . value . linear . weight *= deep_norm_beta
+
151 self_attn . value . linear . weight *= deep_norm_beta
@@ -305,7 +305,7 @@
-
154 self_attn . output . weight *= deep_norm_beta
+
153 self_attn . output . weight *= deep_norm_beta
@@ -317,7 +317,7 @@
@@ -331,7 +331,7 @@
-
159 def forward ( self , x : torch . Tensor ):
+
158 def forward ( self , x : torch . Tensor ):
@@ -343,7 +343,7 @@
-
164 if self . mask is None or self . mask . size ( 0 ) != len ( x ):
+
163 if self . mask is None or self . mask . size ( 0 ) != len ( x ):
@@ -355,7 +355,7 @@
-
166 self . mask = subsequent_mask ( len ( x )) . to ( x . device )
+
165 self . mask = subsequent_mask ( len ( x )) . to ( x . device )
@@ -367,7 +367,7 @@
-
169 x = self . self_attn_norm ( x , self . self_attn ( query = x , key = x , value = x , mask = self . mask ))
+
168 x = self . self_attn_norm ( x , self . self_attn ( query = x , key = x , value = x , mask = self . mask ))
@@ -379,7 +379,7 @@
-
171 x = self . feed_forward_norm ( x , self . feed_forward ( x ))
+
170 x = self . feed_forward_norm ( x , self . feed_forward ( x ))
@@ -391,7 +391,7 @@
Transformer Auto-Regression Experiment
-
+
This trains a simple transformer introduced in Attention Is All You Need on an NLP auto-regression task (with Tiny Shakespeare dataset).
-
17 import torch
-18 from torch import nn
-19
-20 from labml import experiment
-21 from labml.configs import option
-22 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
-23 from labml_nn.transformers import TransformerConfigs , Encoder
-24 from labml_nn.transformers.utils import subsequent_mask
+
16 import torch
+17 from torch import nn
+18
+19 from labml import experiment
+20 from labml.configs import option
+21 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+22 from labml_nn.transformers import TransformerConfigs , Encoder
+23 from labml_nn.transformers.utils import subsequent_mask
@@ -100,7 +100,7 @@
-
27 class AutoregressiveTransformer ( nn . Module ):
+
26 class AutoregressiveTransformer ( nn . Module ):
@@ -117,7 +117,7 @@
-
31 def __init__ ( self , encoder : Encoder , src_embed : nn . Module , generator : nn . Module ):
+
30 def __init__ ( self , encoder : Encoder , src_embed : nn . Module , generator : nn . Module ):
@@ -128,10 +128,10 @@
-
38 super () . __init__ ()
-39 self . src_embed = src_embed
-40 self . encoder = encoder
-41 self . generator = generator
+
37 super () . __init__ ()
+38 self . src_embed = src_embed
+39 self . encoder = encoder
+40 self . generator = generator
@@ -143,7 +143,7 @@
@@ -154,7 +154,7 @@
-
46 def forward ( self , x : torch . Tensor ):
+
45 def forward ( self , x : torch . Tensor ):
@@ -166,7 +166,7 @@
-
49 if self . mask is None or self . mask . size ( 0 ) != len ( x ):
+
48 if self . mask is None or self . mask . size ( 0 ) != len ( x ):
@@ -178,7 +178,7 @@
-
51 self . mask = subsequent_mask ( len ( x )) . to ( x . device )
+
50 self . mask = subsequent_mask ( len ( x )) . to ( x . device )
@@ -190,7 +190,7 @@
@@ -202,7 +202,7 @@
-
55 x = self . encoder ( x , self . mask )
+
54 x = self . encoder ( x , self . mask )
@@ -214,7 +214,7 @@
@@ -226,7 +226,7 @@
@@ -240,7 +240,7 @@
-
64 class Configs ( NLPAutoRegressionConfigs ):
+
63 class Configs ( NLPAutoRegressionConfigs ):
@@ -252,7 +252,7 @@
-
73 model : AutoregressiveTransformer
+
72 model : AutoregressiveTransformer
@@ -264,7 +264,7 @@
-
75 transformer : TransformerConfigs
+
74 transformer : TransformerConfigs
@@ -276,8 +276,8 @@
-
78 @option ( Configs . transformer , 'Transformer' )
-79 def _transformer_configs ( c : Configs ):
+
77 @option ( Configs . transformer , 'Transformer' )
+78 def _transformer_configs ( c : Configs ):
@@ -289,7 +289,7 @@
-
86 conf = TransformerConfigs ()
+
85 conf = TransformerConfigs ()
@@ -301,8 +301,8 @@
-
88 conf . n_src_vocab = c . n_tokens
-89 conf . n_tgt_vocab = c . n_tokens
+
87 conf . n_src_vocab = c . n_tokens
+88 conf . n_tgt_vocab = c . n_tokens
@@ -314,7 +314,7 @@
-
91 conf . d_model = c . d_model
+
90 conf . d_model = c . d_model
@@ -326,7 +326,7 @@
@@ -338,8 +338,8 @@
-
97 @option ( Configs . model )
-98 def _model ( c : Configs ):
+
96 @option ( Configs . model )
+97 def _model ( c : Configs ):
@@ -350,11 +350,11 @@
-
102 m = AutoregressiveTransformer ( c . transformer . encoder ,
-103 c . transformer . src_embed ,
-104 c . transformer . generator ) . to ( c . device )
-105
-106 return m
+
101 m = AutoregressiveTransformer ( c . transformer . encoder ,
+102 c . transformer . src_embed ,
+103 c . transformer . generator ) . to ( c . device )
+104
+105 return m
@@ -365,7 +365,7 @@
@@ -377,7 +377,7 @@
-
111 experiment . create ( name = "transformer" )
+
110 experiment . create ( name = "transformer" )
@@ -389,7 +389,7 @@
@@ -401,7 +401,7 @@
-
115 experiment . configs ( conf , {
+
114 experiment . configs ( conf , {
@@ -413,7 +413,7 @@
-
117 'tokenizer' : 'character' ,
+
116 'tokenizer' : 'character' ,
@@ -425,7 +425,7 @@
-
119 'prompt_separator' : '' ,
+
118 'prompt_separator' : '' ,
@@ -437,7 +437,7 @@
@@ -449,7 +449,7 @@
-
123 'text' : 'tiny_shakespeare' ,
+
122 'text' : 'tiny_shakespeare' ,
@@ -461,7 +461,7 @@
@@ -473,7 +473,7 @@
@@ -485,7 +485,7 @@
@@ -497,7 +497,7 @@
-
133 'inner_iterations' : 10 ,
+
132 'inner_iterations' : 10 ,
@@ -509,9 +509,9 @@
-
136 'd_model' : 256 ,
-137 'transformer.n_heads' : 16 ,
-138 'transformer.ffn.d_ff' : 1024 ,
+
135 'd_model' : 256 ,
+136 'transformer.n_heads' : 16 ,
+137 'transformer.ffn.d_ff' : 1024 ,
@@ -523,9 +523,9 @@
-
141 'optimizer.optimizer' : 'Noam' ,
-142 'optimizer.learning_rate' : 1. ,
-143 })
+
140 'optimizer.optimizer' : 'Noam' ,
+141 'optimizer.learning_rate' : 1. ,
+142 })
@@ -537,7 +537,7 @@
-
146 experiment . add_pytorch_models ({ 'model' : conf . model })
+
145 experiment . add_pytorch_models ({ 'model' : conf . model })
@@ -549,7 +549,7 @@
-
149 with experiment . start ():
+
148 with experiment . start ():
@@ -561,7 +561,7 @@
@@ -573,8 +573,8 @@
-
155 if __name__ == '__main__' :
-156 main ()
+
154 if __name__ == '__main__' :
+155 main ()
Multi-Headed Attention (MHA)
-
+
This is a tutorial/implementation of multi-headed attention from paper Attention Is All You Need in PyTorch . The implementation is inspired from Annotated Transformer .
Here is the training code that uses a basic transformer with MHA for NLP auto-regression.
Here is an experiment implementation that trains a simple transformer.
-
25 import math
-26 from typing import Optional , List
-27
-28 import torch
-29 from torch import nn
-30
-31 from labml import tracker
+
24 import math
+25 from typing import Optional , List
+26
+27 import torch
+28 from torch import nn
+29
+30 from labml import tracker
@@ -102,7 +102,7 @@
-
34 class PrepareForMultiHeadAttention ( nn . Module ):
+
33 class PrepareForMultiHeadAttention ( nn . Module ):
@@ -113,8 +113,8 @@
-
45 def __init__ ( self , d_model : int , heads : int , d_k : int , bias : bool ):
-46 super () . __init__ ()
+
44 def __init__ ( self , d_model : int , heads : int , d_k : int , bias : bool ):
+45 super () . __init__ ()
@@ -126,7 +126,7 @@
-
48 self . linear = nn . Linear ( d_model , heads * d_k , bias = bias )
+
47 self . linear = nn . Linear ( d_model , heads * d_k , bias = bias )
@@ -138,7 +138,7 @@
@@ -150,7 +150,7 @@
@@ -161,7 +161,7 @@
-
54 def forward ( self , x : torch . Tensor ):
+
53 def forward ( self , x : torch . Tensor ):
@@ -175,7 +175,7 @@
-
58 head_shape = x . shape [: - 1 ]
+
57 head_shape = x . shape [: - 1 ]
@@ -187,7 +187,7 @@
@@ -199,7 +199,7 @@
-
64 x = x . view ( * head_shape , self . heads , self . d_k )
+
63 x = x . view ( * head_shape , self . heads , self . d_k )
@@ -213,7 +213,7 @@
@@ -256,7 +256,7 @@ M834 80h400000v40h-400000z">
70 class MultiHeadAttention ( nn . Module ):
+ 69 class MultiHeadAttention ( nn . Module ):
@@ -274,7 +274,7 @@ M834 80h400000v40h-400000z">
91 def __init__ ( self , heads : int , d_model : int , dropout_prob : float = 0.1 , bias : bool = True ):
+ 90 def __init__ ( self , heads : int , d_model : int , dropout_prob : float = 0.1 , bias : bool = True ):
@@ -285,7 +285,7 @@ M834 80h400000v40h-400000z">
97 super () . __init__ ()
+
@@ -297,7 +297,7 @@ M834 80h400000v40h-400000z">
100 self . d_k = d_model // heads
+ 99 self . d_k = d_model // heads
@@ -309,7 +309,7 @@ M834 80h400000v40h-400000z">
102 self . heads = heads
+
@@ -324,9 +324,9 @@ M834 80h400000v40h-400000z">
105 self . query = PrepareForMultiHeadAttention ( d_model , heads , self . d_k , bias = bias )
-106 self . key = PrepareForMultiHeadAttention ( d_model , heads , self . d_k , bias = bias )
-107 self . value = PrepareForMultiHeadAttention ( d_model , heads , self . d_k , bias = True )
+ 104 self . query = PrepareForMultiHeadAttention ( d_model , heads , self . d_k , bias = bias )
+105 self . key = PrepareForMultiHeadAttention ( d_model , heads , self . d_k , bias = bias )
+106 self . value = PrepareForMultiHeadAttention ( d_model , heads , self . d_k , bias = True )
@@ -339,7 +339,7 @@ M834 80h400000v40h-400000z">
110 self . softmax = nn . Softmax ( dim = 1 )
+ 109 self . softmax = nn . Softmax ( dim = 1 )
@@ -351,7 +351,7 @@ M834 80h400000v40h-400000z">
113 self . output = nn . Linear ( d_model , d_model )
+ 112 self . output = nn . Linear ( d_model , d_model )
@@ -363,7 +363,7 @@ M834 80h400000v40h-400000z">
115 self . dropout = nn . Dropout ( dropout_prob )
+ 114 self . dropout = nn . Dropout ( dropout_prob )
@@ -375,7 +375,7 @@ M834 80h400000v40h-400000z">
117 self . scale = 1 / math . sqrt ( self . d_k )
+ 116 self . scale = 1 / math . sqrt ( self . d_k )
@@ -387,7 +387,7 @@ M834 80h400000v40h-400000z">
120 self . attn = None
+
@@ -400,7 +400,7 @@ M834 80h400000v40h-400000z">
122 def get_scores ( self , query : torch . Tensor , key : torch . Tensor ):
+ 121 def get_scores ( self , query : torch . Tensor , key : torch . Tensor ):
@@ -412,7 +412,7 @@ M834 80h400000v40h-400000z">
130 return torch . einsum ( 'ibhd,jbhd->ijbh' , query , key )
+ 129 return torch . einsum ( 'ibhd,jbhd->ijbh' , query , key )
@@ -426,7 +426,7 @@ M834 80h400000v40h-400000z">
132 def prepare_mask ( self , mask : torch . Tensor , query_shape : List [ int ], key_shape : List [ int ]):
+ 131 def prepare_mask ( self , mask : torch . Tensor , query_shape : List [ int ], key_shape : List [ int ]):
@@ -437,9 +437,9 @@ M834 80h400000v40h-400000z">
138 assert mask . shape [ 0 ] == 1 or mask . shape [ 0 ] == query_shape [ 0 ]
-139 assert mask . shape [ 1 ] == key_shape [ 0 ]
-140 assert mask . shape [ 2 ] == 1 or mask . shape [ 2 ] == query_shape [ 1 ]
+ 137 assert mask . shape [ 0 ] == 1 or mask . shape [ 0 ] == query_shape [ 0 ]
+138 assert mask . shape [ 1 ] == key_shape [ 0 ]
+139 assert mask . shape [ 2 ] == 1 or mask . shape [ 2 ] == query_shape [ 1 ]
@@ -451,7 +451,7 @@ M834 80h400000v40h-400000z">
143 mask = mask . unsqueeze ( - 1 )
+ 142 mask = mask . unsqueeze ( - 1 )
@@ -464,7 +464,7 @@ M834 80h400000v40h-400000z">
146 return mask
+
@@ -487,11 +487,11 @@ M834 80h400000v40h-400000z">
148 def forward ( self , * ,
-149 query : torch . Tensor ,
-150 key : torch . Tensor ,
-151 value : torch . Tensor ,
-152 mask : Optional [ torch . Tensor ] = None ):
+ 147 def forward ( self , * ,
+148 query : torch . Tensor ,
+149 key : torch . Tensor ,
+150 value : torch . Tensor ,
+151 mask : Optional [ torch . Tensor ] = None ):
@@ -507,10 +507,10 @@ M834 80h400000v40h-400000z">
164 seq_len , batch_size , _ = query . shape
-165
-166 if mask is not None :
-167 mask = self . prepare_mask ( mask , query . shape , key . shape )
+ 163 seq_len , batch_size , _ = query . shape
+164
+165 if mask is not None :
+166 mask = self . prepare_mask ( mask , query . shape , key . shape )
@@ -526,9 +526,9 @@ M834 80h400000v40h-400000z">
171 query = self . query ( query )
-172 key = self . key ( key )
-173 value = self . value ( value )
+ 170 query = self . query ( query )
+171 key = self . key ( key )
+172 value = self . value ( value )
@@ -541,7 +541,7 @@ M834 80h400000v40h-400000z">
177 scores = self . get_scores ( query , key )
+ 176 scores = self . get_scores ( query , key )
@@ -564,7 +564,7 @@ M834 80h400000v40h-400000z">
180 scores *= self . scale
+
@@ -576,8 +576,8 @@ M834 80h400000v40h-400000z">
183 if mask is not None :
-184 scores = scores . masked_fill ( mask == 0 , float ( '-inf' ))
+ 182 if mask is not None :
+183 scores = scores . masked_fill ( mask == 0 , float ( '-inf' ))
@@ -600,7 +600,7 @@ M834 80h400000v40h-400000z">
188 attn = self . softmax ( scores )
+ 187 attn = self . softmax ( scores )
@@ -612,7 +612,7 @@ M834 80h400000v40h-400000z">
191 tracker . debug ( 'attn' , attn )
+ 190 tracker . debug ( 'attn' , attn )
@@ -624,7 +624,7 @@ M834 80h400000v40h-400000z">
194 attn = self . dropout ( attn )
+ 193 attn = self . dropout ( attn )
@@ -647,7 +647,7 @@ M834 80h400000v40h-400000z">
198 x = torch . einsum ( "ijbh,jbhd->ibhd" , attn , value )
+ 197 x = torch . einsum ( "ijbh,jbhd->ibhd" , attn , value )
@@ -659,7 +659,7 @@ M834 80h400000v40h-400000z">
201 self . attn = attn . detach ()
+ 200 self . attn = attn . detach ()
@@ -671,7 +671,7 @@ M834 80h400000v40h-400000z">
204 x = x . reshape ( seq_len , batch_size , - 1 )
+ 203 x = x . reshape ( seq_len , batch_size , - 1 )
@@ -683,7 +683,7 @@ M834 80h400000v40h-400000z">
207 return self . output ( x )
+ 206 return self . output ( x )
Transformer Encoder and Decoder Models
-
+
-
14 import math
-15
-16 import torch
-17 import torch.nn as nn
-18
-19 from labml_nn.utils import clone_module_list
-20 from .feed_forward import FeedForward
-21 from .mha import MultiHeadAttention
-22 from .positional_encoding import get_positional_encoding
+
13 import math
+14
+15 import torch
+16 import torch.nn as nn
+17
+18 from labml_nn.utils import clone_module_list
+19 from .feed_forward import FeedForward
+20 from .mha import MultiHeadAttention
+21 from .positional_encoding import get_positional_encoding
@@ -100,7 +100,7 @@
-
25 class EmbeddingsWithPositionalEncoding ( nn . Module ):
+
24 class EmbeddingsWithPositionalEncoding ( nn . Module ):
@@ -111,11 +111,11 @@
-
32 def __init__ ( self , d_model : int , n_vocab : int , max_len : int = 5000 ):
-33 super () . __init__ ()
-34 self . linear = nn . Embedding ( n_vocab , d_model )
-35 self . d_model = d_model
-36 self . register_buffer ( 'positional_encodings' , get_positional_encoding ( d_model , max_len ))
+
31 def __init__ ( self , d_model : int , n_vocab : int , max_len : int = 5000 ):
+32 super () . __init__ ()
+33 self . linear = nn . Embedding ( n_vocab , d_model )
+34 self . d_model = d_model
+35 self . register_buffer ( 'positional_encodings' , get_positional_encoding ( d_model , max_len ))
@@ -126,9 +126,9 @@
-
38 def forward ( self , x : torch . Tensor ):
-39 pe = self . positional_encodings [: x . shape [ 0 ]] . requires_grad_ ( False )
-40 return self . linear ( x ) * math . sqrt ( self . d_model ) + pe
+
37 def forward ( self , x : torch . Tensor ):
+38 pe = self . positional_encodings [: x . shape [ 0 ]] . requires_grad_ ( False )
+39 return self . linear ( x ) * math . sqrt ( self . d_model ) + pe
@@ -141,7 +141,7 @@
-
43 class EmbeddingsWithLearnedPositionalEncoding ( nn . Module ):
+
42 class EmbeddingsWithLearnedPositionalEncoding ( nn . Module ):
@@ -152,11 +152,11 @@
-
50 def __init__ ( self , d_model : int , n_vocab : int , max_len : int = 5000 ):
-51 super () . __init__ ()
-52 self . linear = nn . Embedding ( n_vocab , d_model )
-53 self . d_model = d_model
-54 self . positional_encodings = nn . Parameter ( torch . zeros ( max_len , 1 , d_model ), requires_grad = True )
+
49 def __init__ ( self , d_model : int , n_vocab : int , max_len : int = 5000 ):
+50 super () . __init__ ()
+51 self . linear = nn . Embedding ( n_vocab , d_model )
+52 self . d_model = d_model
+53 self . positional_encodings = nn . Parameter ( torch . zeros ( max_len , 1 , d_model ), requires_grad = True )
@@ -167,9 +167,9 @@
-
56 def forward ( self , x : torch . Tensor ):
-57 pe = self . positional_encodings [: x . shape [ 0 ]]
-58 return self . linear ( x ) * math . sqrt ( self . d_model ) + pe
+
55 def forward ( self , x : torch . Tensor ):
+56 pe = self . positional_encodings [: x . shape [ 0 ]]
+57 return self . linear ( x ) * math . sqrt ( self . d_model ) + pe
@@ -184,7 +184,7 @@
-
61 class TransformerLayer ( nn . Module ):
+
60 class TransformerLayer ( nn . Module ):
@@ -205,12 +205,12 @@
-
79 def __init__ ( self , * ,
-80 d_model : int ,
-81 self_attn : MultiHeadAttention ,
-82 src_attn : MultiHeadAttention = None ,
-83 feed_forward : FeedForward ,
-84 dropout_prob : float ):
+
78 def __init__ ( self , * ,
+79 d_model : int ,
+80 self_attn : MultiHeadAttention ,
+81 src_attn : MultiHeadAttention = None ,
+82 feed_forward : FeedForward ,
+83 dropout_prob : float ):
@@ -221,16 +221,16 @@
-
92 super () . __init__ ()
-93 self . size = d_model
-94 self . self_attn = self_attn
-95 self . src_attn = src_attn
-96 self . feed_forward = feed_forward
-97 self . dropout = nn . Dropout ( dropout_prob )
-98 self . norm_self_attn = nn . LayerNorm ([ d_model ])
-99 if self . src_attn is not None :
-100 self . norm_src_attn = nn . LayerNorm ([ d_model ])
-101 self . norm_ff = nn . LayerNorm ([ d_model ])
+
91 super () . __init__ ()
+92 self . size = d_model
+93 self . self_attn = self_attn
+94 self . src_attn = src_attn
+95 self . feed_forward = feed_forward
+96 self . dropout = nn . Dropout ( dropout_prob )
+97 self . norm_self_attn = nn . LayerNorm ([ d_model ])
+98 if self . src_attn is not None :
+99 self . norm_src_attn = nn . LayerNorm ([ d_model ])
+100 self . norm_ff = nn . LayerNorm ([ d_model ])
@@ -242,7 +242,7 @@
-
103 self . is_save_ff_input = False
+
102 self . is_save_ff_input = False
@@ -253,11 +253,11 @@
-
105 def forward ( self , * ,
-106 x : torch . Tensor ,
-107 mask : torch . Tensor ,
-108 src : torch . Tensor = None ,
-109 src_mask : torch . Tensor = None ):
+
104 def forward ( self , * ,
+105 x : torch . Tensor ,
+106 mask : torch . Tensor ,
+107 src : torch . Tensor = None ,
+108 src_mask : torch . Tensor = None ):
@@ -269,7 +269,7 @@
-
111 z = self . norm_self_attn ( x )
+
110 z = self . norm_self_attn ( x )
@@ -281,7 +281,7 @@
-
113 self_attn = self . self_attn ( query = z , key = z , value = z , mask = mask )
+
112 self_attn = self . self_attn ( query = z , key = z , value = z , mask = mask )
@@ -293,7 +293,7 @@
-
115 x = x + self . dropout ( self_attn )
+
114 x = x + self . dropout ( self_attn )
@@ -305,7 +305,7 @@
@@ -317,7 +317,7 @@
-
122 z = self . norm_src_attn ( x )
+
121 z = self . norm_src_attn ( x )
@@ -329,7 +329,7 @@
-
124 attn_src = self . src_attn ( query = z , key = src , value = src , mask = src_mask )
+
123 attn_src = self . src_attn ( query = z , key = src , value = src , mask = src_mask )
@@ -341,7 +341,7 @@
-
126 x = x + self . dropout ( attn_src )
+
125 x = x + self . dropout ( attn_src )
@@ -353,7 +353,7 @@
@@ -365,8 +365,8 @@
-
131 if self . is_save_ff_input :
-132 self . ff_input = z . clone ()
+
130 if self . is_save_ff_input :
+131 self . ff_input = z . clone ()
@@ -378,7 +378,7 @@
-
134 ff = self . feed_forward ( z )
+
133 ff = self . feed_forward ( z )
@@ -390,9 +390,9 @@
-
136 x = x + self . dropout ( ff )
-137
-138 return x
+
135 x = x + self . dropout ( ff )
+136
+137 return x
@@ -405,7 +405,7 @@
-
141 class Encoder ( nn . Module ):
+
140 class Encoder ( nn . Module ):
@@ -416,8 +416,8 @@
-
148 def __init__ ( self , layer : TransformerLayer , n_layers : int ):
-149 super () . __init__ ()
+
147 def __init__ ( self , layer : TransformerLayer , n_layers : int ):
+148 super () . __init__ ()
@@ -429,7 +429,7 @@
-
151 self . layers = clone_module_list ( layer , n_layers )
+
150 self . layers = clone_module_list ( layer , n_layers )
@@ -441,7 +441,7 @@
-
153 self . norm = nn . LayerNorm ([ layer . size ])
+
152 self . norm = nn . LayerNorm ([ layer . size ])
@@ -452,7 +452,7 @@
-
155 def forward ( self , x : torch . Tensor , mask : torch . Tensor ):
+
154 def forward ( self , x : torch . Tensor , mask : torch . Tensor ):
@@ -464,8 +464,8 @@
-
157 for layer in self . layers :
-158 x = layer ( x = x , mask = mask )
+
156 for layer in self . layers :
+157 x = layer ( x = x , mask = mask )
@@ -477,7 +477,7 @@
@@ -490,7 +490,7 @@
-
163 class Decoder ( nn . Module ):
+
162 class Decoder ( nn . Module ):
@@ -501,8 +501,8 @@
-
170 def __init__ ( self , layer : TransformerLayer , n_layers : int ):
-171 super () . __init__ ()
+
169 def __init__ ( self , layer : TransformerLayer , n_layers : int ):
+170 super () . __init__ ()
@@ -514,7 +514,7 @@
-
173 self . layers = clone_module_list ( layer , n_layers )
+
172 self . layers = clone_module_list ( layer , n_layers )
@@ -526,7 +526,7 @@
-
175 self . norm = nn . LayerNorm ([ layer . size ])
+
174 self . norm = nn . LayerNorm ([ layer . size ])
@@ -537,7 +537,7 @@
-
177 def forward ( self , x : torch . Tensor , memory : torch . Tensor , src_mask : torch . Tensor , tgt_mask : torch . Tensor ):
+
176 def forward ( self , x : torch . Tensor , memory : torch . Tensor , src_mask : torch . Tensor , tgt_mask : torch . Tensor ):
@@ -549,8 +549,8 @@
-
179 for layer in self . layers :
-180 x = layer ( x = x , mask = tgt_mask , src = memory , src_mask = src_mask )
+
178 for layer in self . layers :
+179 x = layer ( x = x , mask = tgt_mask , src = memory , src_mask = src_mask )
@@ -562,7 +562,7 @@
@@ -577,7 +577,7 @@
-
185 class Generator ( nn . Module ):
+
184 class Generator ( nn . Module ):
@@ -588,9 +588,9 @@
-
195 def __init__ ( self , n_vocab : int , d_model : int ):
-196 super () . __init__ ()
-197 self . projection = nn . Linear ( d_model , n_vocab )
+
194 def __init__ ( self , n_vocab : int , d_model : int ):
+195 super () . __init__ ()
+196 self . projection = nn . Linear ( d_model , n_vocab )
@@ -601,8 +601,8 @@
-
199 def forward ( self , x ):
-200 return self . projection ( x )
+
198 def forward ( self , x ):
+199 return self . projection ( x )
@@ -615,7 +615,7 @@
-
203 class EncoderDecoder ( nn . Module ):
+
202 class EncoderDecoder ( nn . Module ):
@@ -626,13 +626,13 @@
-
210 def __init__ ( self , encoder : Encoder , decoder : Decoder , src_embed : nn . Module , tgt_embed : nn . Module , generator : nn . Module ):
-211 super () . __init__ ()
-212 self . encoder = encoder
-213 self . decoder = decoder
-214 self . src_embed = src_embed
-215 self . tgt_embed = tgt_embed
-216 self . generator = generator
+
209 def __init__ ( self , encoder : Encoder , decoder : Decoder , src_embed : nn . Module , tgt_embed : nn . Module , generator : nn . Module ):
+210 super () . __init__ ()
+211 self . encoder = encoder
+212 self . decoder = decoder
+213 self . src_embed = src_embed
+214 self . tgt_embed = tgt_embed
+215 self . generator = generator
@@ -644,9 +644,9 @@
-
220 for p in self . parameters ():
-221 if p . dim () > 1 :
-222 nn . init . xavier_uniform_ ( p )
+
219 for p in self . parameters ():
+220 if p . dim () > 1 :
+221 nn . init . xavier_uniform_ ( p )
@@ -657,7 +657,7 @@
-
224 def forward ( self , src : torch . Tensor , tgt : torch . Tensor , src_mask : torch . Tensor , tgt_mask : torch . Tensor ):
+
223 def forward ( self , src : torch . Tensor , tgt : torch . Tensor , src_mask : torch . Tensor , tgt_mask : torch . Tensor ):
@@ -669,7 +669,7 @@
-
226 enc = self . encode ( src , src_mask )
+
225 enc = self . encode ( src , src_mask )
@@ -681,7 +681,7 @@
-
228 return self . decode ( enc , src_mask , tgt , tgt_mask )
+
227 return self . decode ( enc , src_mask , tgt , tgt_mask )
@@ -692,8 +692,8 @@
-
230 def encode ( self , src : torch . Tensor , src_mask : torch . Tensor ):
-231 return self . encoder ( self . src_embed ( src ), src_mask )
+
229 def encode ( self , src : torch . Tensor , src_mask : torch . Tensor ):
+230 return self . encoder ( self . src_embed ( src ), src_mask )
@@ -704,8 +704,8 @@
-
233 def decode ( self , memory : torch . Tensor , src_mask : torch . Tensor , tgt : torch . Tensor , tgt_mask : torch . Tensor ):
-234 return self . decoder ( self . tgt_embed ( tgt ), memory , src_mask , tgt_mask )
+
232 def decode ( self , memory : torch . Tensor , src_mask : torch . Tensor , tgt : torch . Tensor , tgt_mask : torch . Tensor ):
+233 return self . decoder ( self . tgt_embed ( tgt ), memory , src_mask , tgt_mask )
-
45 experiment . create ( name = "roper_addition" , comment = "rotary value 7" , writers = { 'screen' , 'labml' , 'comet' })
+
45 experiment . create ( name = "roper_addition" , comment = "rotary value 7" , writers = { 'screen' , 'labml' })
diff --git a/docs/transformers/rope/value_pe/experiment.html b/docs/transformers/rope/value_pe/experiment.html
index a9fcde20..012d96c2 100644
--- a/docs/transformers/rope/value_pe/experiment.html
+++ b/docs/transformers/rope/value_pe/experiment.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/transformers/rope/value_pe/index.html b/docs/transformers/rope/value_pe/index.html
index 894aa595..c9cb99d6 100644
--- a/docs/transformers/rope/value_pe/index.html
+++ b/docs/transformers/rope/value_pe/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/transformers/switch/experiment.html b/docs/transformers/switch/experiment.html
index da78d737..4b9e4069 100644
--- a/docs/transformers/switch/experiment.html
+++ b/docs/transformers/switch/experiment.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/transformers/switch/index.html b/docs/transformers/switch/index.html
index 9562fc66..bb3f6e75 100644
--- a/docs/transformers/switch/index.html
+++ b/docs/transformers/switch/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/transformers/switch/readme.html b/docs/transformers/switch/readme.html
index 31dedc63..94b76e85 100644
--- a/docs/transformers/switch/readme.html
+++ b/docs/transformers/switch/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/transformers/utils.html b/docs/transformers/utils.html
index f03a78f4..5e2332ab 100644
--- a/docs/transformers/utils.html
+++ b/docs/transformers/utils.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/transformers/vit/experiment.html b/docs/transformers/vit/experiment.html
index 8e0ab9ed..c8128d6d 100644
--- a/docs/transformers/vit/experiment.html
+++ b/docs/transformers/vit/experiment.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/transformers/vit/index.html b/docs/transformers/vit/index.html
index 6696998f..0ebd074e 100644
--- a/docs/transformers/vit/index.html
+++ b/docs/transformers/vit/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/transformers/vit/readme.html b/docs/transformers/vit/readme.html
index a18dfa04..d5d52600 100644
--- a/docs/transformers/vit/readme.html
+++ b/docs/transformers/vit/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/transformers/xl/experiment.html b/docs/transformers/xl/experiment.html
index c34bcb20..08ff84dd 100644
--- a/docs/transformers/xl/experiment.html
+++ b/docs/transformers/xl/experiment.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/transformers/xl/index.html b/docs/transformers/xl/index.html
index 5d012e86..db13a79e 100644
--- a/docs/transformers/xl/index.html
+++ b/docs/transformers/xl/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/transformers/xl/readme.html b/docs/transformers/xl/readme.html
index 5174f2b9..402f14b7 100644
--- a/docs/transformers/xl/readme.html
+++ b/docs/transformers/xl/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/transformers/xl/relative_mha.html b/docs/transformers/xl/relative_mha.html
index 52dcd665..f617d7da 100644
--- a/docs/transformers/xl/relative_mha.html
+++ b/docs/transformers/xl/relative_mha.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/uncertainty/evidence/experiment.html b/docs/uncertainty/evidence/experiment.html
index 3a759daa..f52e0cba 100644
--- a/docs/uncertainty/evidence/experiment.html
+++ b/docs/uncertainty/evidence/experiment.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/uncertainty/evidence/index.html b/docs/uncertainty/evidence/index.html
index 2e340588..74fa8a00 100644
--- a/docs/uncertainty/evidence/index.html
+++ b/docs/uncertainty/evidence/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/uncertainty/evidence/readme.html b/docs/uncertainty/evidence/readme.html
index a4209cc2..96913440 100644
--- a/docs/uncertainty/evidence/readme.html
+++ b/docs/uncertainty/evidence/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/uncertainty/index.html b/docs/uncertainty/index.html
index 94af4c96..c885d6c2 100644
--- a/docs/uncertainty/index.html
+++ b/docs/uncertainty/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/uncertainty/readme.html b/docs/uncertainty/readme.html
index 4cb20dd5..cba06c9c 100644
--- a/docs/uncertainty/readme.html
+++ b/docs/uncertainty/readme.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/unet/carvana.html b/docs/unet/carvana.html
index b860d11c..83cff98e 100644
--- a/docs/unet/carvana.html
+++ b/docs/unet/carvana.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/unet/experiment.html b/docs/unet/experiment.html
index d65e2aba..26cfabe3 100644
--- a/docs/unet/experiment.html
+++ b/docs/unet/experiment.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/unet/index.html b/docs/unet/index.html
index 628b9b67..e3476026 100644
--- a/docs/unet/index.html
+++ b/docs/unet/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/utils/index.html b/docs/utils/index.html
index 5a2b4386..3f12a5a2 100644
--- a/docs/utils/index.html
+++ b/docs/utils/index.html
@@ -1,5 +1,5 @@
-
+
diff --git a/docs/utils/tokenizer.html b/docs/utils/tokenizer.html
index a473f647..286084d7 100644
--- a/docs/utils/tokenizer.html
+++ b/docs/utils/tokenizer.html
@@ -1,5 +1,5 @@
-
+
diff --git a/labml_nn/diffusion/ddpm/__init__.py b/labml_nn/diffusion/ddpm/__init__.py
index b678fb37..d78f6ad4 100644
--- a/labml_nn/diffusion/ddpm/__init__.py
+++ b/labml_nn/diffusion/ddpm/__init__.py
@@ -144,7 +144,7 @@ That is, we are training to predict the noise.
### Simplified loss
-$$L_simple(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
+$$L_{\text{simple}}(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
\epsilon - \textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
\bigg\Vert^2 \Bigg]$$
@@ -265,7 +265,7 @@ class DenoiseDiffusion:
"""
#### Simplified Loss
- $$L_simple(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
+ $$L_{\text{simple}}(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
\epsilon - \textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
\bigg\Vert^2 \Bigg]$$
"""