diff --git a/labml_nn/normalization/group_norm/__init__.py b/labml_nn/normalization/group_norm/__init__.py index 6fee2e3c..00ff3f47 100644 --- a/labml_nn/normalization/group_norm/__init__.py +++ b/labml_nn/normalization/group_norm/__init__.py @@ -60,10 +60,10 @@ class GroupNorm(Module): # Calculate the mean across first and last dimension; # i.e. the means for each feature $\mathbb{E}[x^{(k)}]$ - mean = x.mean(dim=[2], keepdims=True) + mean = x.mean(dim=[2], keepdim=True) # Calculate the squared mean across first and last dimension; # i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$ - mean_x2 = (x ** 2).mean(dim=[2], keepdims=True) + mean_x2 = (x ** 2).mean(dim=[2], keepdim=True) # Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$ var = mean_x2 - mean ** 2 diff --git a/labml_nn/normalization/instance_norm/__init__.py b/labml_nn/normalization/instance_norm/__init__.py index 11c4ec30..d846c5e6 100644 --- a/labml_nn/normalization/instance_norm/__init__.py +++ b/labml_nn/normalization/instance_norm/__init__.py @@ -57,10 +57,10 @@ class InstanceNorm(Module): # Calculate the mean across first and last dimension; # i.e. the means for each feature $\mathbb{E}[x^{(k)}]$ - mean = x.mean(dim=[2], keepdims=True) + mean = x.mean(dim=[2], keepdim=True) # Calculate the squared mean across first and last dimension; # i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$ - mean_x2 = (x ** 2).mean(dim=[2], keepdims=True) + mean_x2 = (x ** 2).mean(dim=[2], keepdim=True) # Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$ var = mean_x2 - mean ** 2 diff --git a/labml_nn/normalization/layer_norm/__init__.py b/labml_nn/normalization/layer_norm/__init__.py index 88f88e45..e4bcabc6 100644 --- a/labml_nn/normalization/layer_norm/__init__.py +++ b/labml_nn/normalization/layer_norm/__init__.py @@ -106,10 +106,10 @@ class LayerNorm(Module): # Calculate the mean of all elements; # i.e. the means for each element $\mathbb{E}[X]$ - mean = x.mean(dim=dims, keepdims=True) + mean = x.mean(dim=dims, keepdim=True) # Calculate the squared mean of all elements; # i.e. the means for each element $\mathbb{E}[X^2]$ - mean_x2 = (x ** 2).mean(dim=dims, keepdims=True) + mean_x2 = (x ** 2).mean(dim=dims, keepdim=True) # Variance of all element $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$ var = mean_x2 - mean ** 2