From af6b99a5f03abdb46cc71e53daa5a8cc9bb747ae Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri 
Calculate the batch statistics in training mode, without backpropagating
through them:

        if self.training:
            with torch.no_grad():

Calculate the mean across first and last dimensions;
$\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$

                mean = x.mean(dim=[0, 2])

Calculate the squared mean across first and last dimensions;
$\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$

                mean_x2 = (x ** 2).mean(dim=[0, 2])

Variance for each feature;
$\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C\big)^2$

                var = mean_x2 - mean ** 2

Update exponential moving averages of the mean and variance

                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var

Normalize with the estimated mean and variance;
$\hat{X}_{\cdot, C, \cdot} = \frac{X_{\cdot, C, \cdot} - \hat{\mu}_C}{\sqrt{\hat{\sigma}^2_C + \epsilon}}$

        x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)

Scale and shift channel-wise; $Y_{\cdot, C, \cdot} = \gamma_C \hat{X}_{\cdot, C, \cdot} + \beta_C$

        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

Reshape to original and return

        return x_norm.view(x_shape)
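The variance above is computed in a single pass using the identity
$Var[X] = E[X^2] - E[X]^2$ instead of subtracting the mean first. A minimal
sketch checking it against PyTorch's biased variance (the shapes here are
illustrative assumptions, not from the patch):

    import torch

    # Illustrative input: batch of 4, 8 channels, 16 flattened spatial positions
    x = torch.randn(4, 8, 16)

    # One-pass estimate, as in the code above
    mean = x.mean(dim=[0, 2])
    mean_x2 = (x ** 2).mean(dim=[0, 2])
    var = mean_x2 - mean ** 2

    # Reference: variance normalized by B*H*W, matching the $\frac{1}{B H W}$ factor
    var_ref = x.var(dim=[0, 2], unbiased=False)
    assert torch.allclose(var, var_ref, atol=1e-5)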
This is similar to Group Normalization, but the affine transform is applied
group-wise.

class ChannelNorm(Module):

    def __init__(self, channels, groups,
                 eps: float = 1e-5, affine: bool = True):
        super().__init__()
        self.channels = channels
        self.groups = groups
        self.eps = eps
        self.affine = affine

Parameters for the affine transform; note that, unlike group normalization,
these are per group rather than per channel

        if self.affine:
            self.scale = nn.Parameter(torch.ones(groups))
            self.shift = nn.Parameter(torch.zeros(groups))
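To make the group-wise affine concrete: with, say, 64 channels in 8 groups,
the layer above learns only 8 scales and 8 shifts, whereas the channel-wise
affine of `nn.GroupNorm` learns 64 of each. A minimal sketch with those
illustrative sizes:

    import torch
    from torch import nn

    # Group-wise affine, as in ChannelNorm: one parameter per group
    scale = nn.Parameter(torch.ones(8))
    print(scale.shape)  # torch.Size([8])

    # Channel-wise affine, as in nn.GroupNorm: one parameter per channel
    gn = nn.GroupNorm(num_groups=8, num_channels=64, affine=True)
    print(gn.weight.shape)  # torch.Size([64])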
`x` is a tensor of shape `[batch_size, channels, height, width]`

    def __call__(self, x: torch.Tensor):

Keep the original shape

        x_shape = x.shape

Get the batch size

        batch_size = x_shape[0]

Sanity check to make sure the number of features is the same

        assert self.channels == x.shape[1]

Reshape into `[batch_size, groups, n]`

        x = x.view(batch_size, self.groups, -1)

Calculate the mean and squared mean across the last dimension, and from them
the variance for each sample and feature group

        mean = x.mean(dim=[-1], keepdim=True)
        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
        var = mean_x2 - mean ** 2

Normalize

        x_norm = (x - mean) / torch.sqrt(var + self.eps)

Scale and shift group-wise

        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)

Reshape to original and return

        return x_norm.view(x_shape)
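A short usage sketch of the class above (the sizes are illustrative
assumptions; the `Module` base class is the one imported in the full file):

    import torch

    # 64 channels split into 8 groups, applied to a batch of 2 32x32 images
    cn = ChannelNorm(channels=64, groups=8)
    x = torch.randn(2, 64, 32, 32)
    y = cn(x)

    # The shape is preserved; before the affine transform, each
    # (sample, group) slice is normalized to zero mean and unit variance
    assert y.shape == x.shape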