diff --git a/docs/normalization/layer_norm/index.html b/docs/normalization/layer_norm/index.html
index ff2e5fdd..da2d4e13 100644
--- a/docs/normalization/layer_norm/index.html
+++ b/docs/normalization/layer_norm/index.html
@@ -91,14 +91,12 @@

Layer normalization transforms the inputs to have zero mean and unit variance across the features. Note that batch normalization fixes the zero mean and unit variance for each feature, computed across the batch, while layer normalization fixes them for each input independently, computed across all of its features.
Layer normalization is generally used for NLP tasks.
We have used layer normalization in most of the transformer implementations.
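To make the contrast concrete, here is a minimal sketch of which axes the statistics are computed over; the shapes and variable names are our own illustration, not part of the implementation below.

```python
import torch

x = torch.randn(4, 8)  # a hypothetical batch of 4 samples with 8 features each

# Batch normalization: one statistic per *feature*, computed across the batch
bn_mean = x.mean(dim=0)   # shape [8]

# Layer normalization: one statistic per *sample*, computed across its features
ln_mean = x.mean(dim=-1)  # shape [4]

# Layer-norm style normalization of each sample with its own statistics
eps = 1e-5
ln_out = (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
```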
```python
import torch
from torch import nn
```

We've tried to use the same names for arguments as the PyTorch BatchNorm implementation.

```python
class BatchNorm(nn.Module):
```
- `channels` is the number of features in the input
- `eps` is $\epsilon$, used in $\sqrt{Var[x^{(k)}] + \epsilon}$ for numerical stability
- `momentum` is the momentum in taking the exponential moving average
- `affine` is whether to scale and shift the normalized value
- `track_running_stats` is whether to calculate the moving averages of mean and variance

```python
    def __init__(self, channels: int, *,
                 eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = True, track_running_stats: bool = True):
        super().__init__()
```
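Since `*` makes every argument after `channels` keyword-only, construction looks like this (a hypothetical usage with the defaults spelled out):

```python
bn = BatchNorm(64, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)
```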
```python
        self.channels = channels

        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
```

Create parameters for $\gamma$ and $\beta$ for scale and shift.
```python
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))
```

Create buffers to store the exponential moving averages of mean and variance.

```python
        if self.track_running_stats:
            self.register_buffer('exp_mean', torch.zeros(channels))
            self.register_buffer('exp_var', torch.ones(channels))
```
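As an aside, `register_buffer` keeps `exp_mean` and `exp_var` in the module's `state_dict` and moves them along with the module across devices, without making them trainable. A small check (our own usage sketch, assuming the class above):

```python
bn = BatchNorm(64)
print([name for name, _ in bn.named_parameters()])  # ['scale', 'shift']
print([name for name, _ in bn.named_buffers()])     # ['exp_mean', 'exp_var']
```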
`x` is a tensor of shape `[batch_size, channels, *]`, where `*` can be any number of dimensions; for example, the output of a 2D convolution would be `[batch_size, channels, height, width]`.

```python
    def forward(self, x: torch.Tensor):
```

Keep the original shape and get the batch size.

```python
        x_shape = x.shape
        batch_size = x_shape[0]
```

Sanity check to make sure the number of features is the same.

```python
        assert self.channels == x.shape[1]
```
Reshape into `[batch_size, channels, n]`.

```python
        x = x.view(batch_size, self.channels, -1)
```

Calculate the mini-batch mean and variance if we are in training mode, or if we are not tracking running statistics.

```python
        if self.training or not self.track_running_stats:
```

Mean across the first and last dimensions; i.e. the mean for each feature, $\mathbb{E}[x^{(k)}]$, and similarly the mean of the squares, $\mathbb{E}[(x^{(k)})^2]$.

```python
            mean = x.mean(dim=[0, 2])
            mean_x2 = (x ** 2).mean(dim=[0, 2])
```

Variance for each feature, $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$.

```python
            var = mean_x2 - mean ** 2
```
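One caveat worth noting (our own aside): computing the variance as $\mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$ can lose precision when the mean is large relative to the spread, since it subtracts two large, nearly equal numbers. A quick illustration in float32:

```python
import torch

x = 1e4 + torch.randn(10_000)               # large mean, small spread, float32
var_fast = (x ** 2).mean() - x.mean() ** 2  # one-pass E[x^2] - E[x]^2
var_safe = ((x - x.mean()) ** 2).mean()     # two-pass formula
print(var_fast.item(), var_safe.item())     # one-pass can be visibly off, even negative
```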
Update the exponential moving averages.

```python
            if self.training and self.track_running_stats:
                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
```
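To see what the momentum does, here is a small worked example (plain Python, our own illustration): with `momentum = 0.1` the running mean moves a tenth of the way toward each new batch statistic.

```python
momentum = 0.1
exp_mean = 0.0    # the buffer starts at zeros
batch_mean = 2.0  # assume every mini-batch happens to have mean 2.0

for step in range(3):
    exp_mean = (1 - momentum) * exp_mean + momentum * batch_mean
    print(f"step {step}: exp_mean = {exp_mean:.3f}")
# step 0: exp_mean = 0.200
# step 1: exp_mean = 0.380
# step 2: exp_mean = 0.542
```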
Otherwise, use the exponential moving averages as estimates of mean and variance.

```python
        else:
            mean = self.exp_mean
            var = self.exp_var
```

Normalize, $\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}] + \epsilon}}$; the statistics are reshaped to `[1, channels, 1]` so they broadcast over the batch and the remaining dimensions.

```python
        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
```

Scale and shift, $y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$.

```python
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
```
Reshape to the original shape and return.

```python
        return x_norm.view(x_shape)
```
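Finally, a quick usage sketch (our own example, not from the original page): run a random image-shaped batch through the module and check that each channel of the output is normalized.

```python
import torch

bn = BatchNorm(3)
x = torch.randn(2, 3, 4, 4)  # [batch_size, channels, height, width]
y = bn(x)

assert y.shape == x.shape
# In training mode, each channel of the output has ~zero mean and ~unit variance
print(y.mean(dim=[0, 2, 3]))                 # close to zeros
print(y.var(dim=[0, 2, 3], unbiased=False))  # close to ones
```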