From af6b99a5f03abdb46cc71e53daa5a8cc9bb747ae Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Sun, 4 Jul 2021 13:32:56 +0530
Subject: [PATCH] batch channel norm mathjax fix

---
 .../batch_channel_norm/index.html  | 70 ++++++++++---------
 .../batch_channel_norm/__init__.py |  7 +-
 2 files changed, 41 insertions(+), 36 deletions(-)

diff --git a/docs/normalization/batch_channel_norm/index.html b/docs/normalization/batch_channel_norm/index.html
index 8f3f8316..0118df70 100644
--- a/docs/normalization/batch_channel_norm/index.html
+++ b/docs/normalization/batch_channel_norm/index.html
@@ -344,7 +344,8 @@ $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.


 Calculate the mean across first and last dimensions;
-$\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$
+$$\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$$

140                mean = x.mean(dim=[0, 2])
@@ -356,7 +357,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$


 Calculate the squared mean across first and last dimensions;
-$\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$
+$$\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$$

143                mean_x2 = (x ** 2).mean(dim=[0, 2])
@@ -367,10 +369,12 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$

-Variance for each feature \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
+Variance for each feature
+$$\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2$$
-145                var = mean_x2 - mean ** 2
+146                var = mean_x2 - mean ** 2
@@ -386,8 +390,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$

-152                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
-153                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
+153                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
+154                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
@@ -400,7 +404,7 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$

-157        x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)
+158        x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)
@@ -415,8 +419,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$

-162        if self.affine:
-163            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
+163        if self.affine:
+164            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
@@ -427,7 +431,7 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$

Reshape to original and return

-166        return x_norm.view(x_shape)
+167        return x_norm.view(x_shape)
@@ -439,7 +443,7 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$

This is similar to Group Normalization but affine transform is done group wise.

-169class ChannelNorm(Module):
+170class ChannelNorm(Module):
@@ -455,8 +459,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$

-176    def __init__(self, channels, groups,
-177                 eps: float = 1e-5, affine: bool = True):
+177    def __init__(self, channels, groups,
+178                 eps: float = 1e-5, affine: bool = True):
@@ -467,11 +471,11 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$

-184        super().__init__()
-185        self.channels = channels
-186        self.groups = groups
-187        self.eps = eps
-188        self.affine = affine
+185        super().__init__()
+186        self.channels = channels
+187        self.groups = groups
+188        self.eps = eps
+189        self.affine = affine
@@ -484,9 +488,9 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$

they are transformed channel-wise.

-193        if self.affine:
-194            self.scale = nn.Parameter(torch.ones(groups))
-195            self.shift = nn.Parameter(torch.zeros(groups))
+194        if self.affine:
+195            self.scale = nn.Parameter(torch.ones(groups))
+196            self.shift = nn.Parameter(torch.zeros(groups))
@@ -500,7 +504,7 @@ they are transformed channel-wise.

[batch_size, channels, height, width]

-197    def __call__(self, x: torch.Tensor):
+198    def __call__(self, x: torch.Tensor):
@@ -511,7 +515,7 @@ they are transformed channel-wise.

Keep the original shape

-206        x_shape = x.shape
+207        x_shape = x.shape
@@ -522,7 +526,7 @@ they are transformed channel-wise.

Get the batch size

-208        batch_size = x_shape[0]
+209        batch_size = x_shape[0]
@@ -533,7 +537,7 @@ they are transformed channel-wise.

Sanity check to make sure the number of features is the same

-210        assert self.channels == x.shape[1]
+211        assert self.channels == x.shape[1]
@@ -544,7 +548,7 @@ they are transformed channel-wise.

Reshape into [batch_size, groups, n]

-213        x = x.view(batch_size, self.groups, -1)
+214        x = x.view(batch_size, self.groups, -1)
@@ -556,7 +560,7 @@ they are transformed channel-wise.

i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$

-217        mean = x.mean(dim=[-1], keepdim=True)
+218        mean = x.mean(dim=[-1], keepdim=True)
@@ -568,7 +572,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$

-220        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
+221        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
@@ -580,7 +584,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$<
 $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$

-223        var = mean_x2 - mean ** 2
+224        var = mean_x2 - mean ** 2
@@ -594,7 +598,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]

-228        x_norm = (x - mean) / torch.sqrt(var + self.eps)
+229        x_norm = (x - mean) / torch.sqrt(var + self.eps)
@@ -607,8 +611,8 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]

-232        if self.affine:
-233            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
+233        if self.affine:
+234            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
@@ -619,7 +623,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]

Reshape to original and return

-236        return x_norm.view(x_shape)
+237        return x_norm.view(x_shape)
diff --git a/labml_nn/normalization/batch_channel_norm/__init__.py b/labml_nn/normalization/batch_channel_norm/__init__.py
index 7c3c44ea..66714424 100644
--- a/labml_nn/normalization/batch_channel_norm/__init__.py
+++ b/labml_nn/normalization/batch_channel_norm/__init__.py
@@ -136,12 +136,13 @@ class EstimatedBatchNorm(Module):
             # No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$
             with torch.no_grad():
                 # Calculate the mean across first and last dimensions;
-                # $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$
+                # $$\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$$
                 mean = x.mean(dim=[0, 2])
                 # Calculate the squared mean across first and last dimensions;
-                # $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$
+                # $$\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$$
                 mean_x2 = (x ** 2).mean(dim=[0, 2])
-                # Variance for each feature \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
+                # Variance for each feature
+                # $$\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2$$
                 var = mean_x2 - mean ** 2
 
                 # Update exponential moving averages
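
For reference, the comments touched by this patch document how `EstimatedBatchNorm` computes its per-channel statistics: the mean $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$, the mean of squares $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$, the variance as $\mathbb{E}[X^2] - \mathbb{E}[X]^2$, exponential moving averages of both, and normalization with those running estimates. The snippet below is a minimal standalone sketch of that computation; the shapes, `momentum`, and `eps` values are illustrative and not taken from the patch.

```python
import torch

# Illustrative sizes: batch, channels, height, width (assumed for this sketch)
B, C, H, W = 4, 8, 16, 16
momentum, eps = 0.1, 1e-5

x = torch.randn(B, C, H, W)
# Flatten the spatial dimensions so channels sit in the middle: [B, C, H * W]
x = x.view(B, C, -1)

# Running estimates, initialized the way such a module typically would
exp_mean = torch.zeros(C)
exp_var = torch.ones(C)

with torch.no_grad():
    # Mean across first and last dimensions: (1 / BHW) * sum_{b,h,w} X_{b,c,h,w}
    mean = x.mean(dim=[0, 2])
    # Mean of squares across the same dimensions
    mean_x2 = (x ** 2).mean(dim=[0, 2])
    # Variance for each channel: E[X^2] - E[X]^2
    var = mean_x2 - mean ** 2
    # Exponential moving averages of the batch statistics
    exp_mean = (1 - momentum) * exp_mean + momentum * mean
    exp_var = (1 - momentum) * exp_var + momentum * var

# Normalize with the running estimates, broadcasting over batch and spatial dims
x_norm = (x - exp_mean.view(1, -1, 1)) / torch.sqrt(exp_var + eps).view(1, -1, 1)
print(x_norm.shape)  # torch.Size([4, 8, 256])
```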