From e47e0a5c829ebdfd65b47f026b375de556be2390 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Tue, 20 Apr 2021 17:05:16 +0530 Subject: [PATCH] group norm cleanup --- labml_nn/normalization/group_norm/__init__.py | 6 +++--- labml_nn/normalization/group_norm/experiment.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/labml_nn/normalization/group_norm/__init__.py b/labml_nn/normalization/group_norm/__init__.py index 0e6f6e70..6fee2e3c 100644 --- a/labml_nn/normalization/group_norm/__init__.py +++ b/labml_nn/normalization/group_norm/__init__.py @@ -56,14 +56,14 @@ class GroupNorm(Module): assert self.channels == x.shape[1] # Reshape into `[batch_size, channels, n]` - x = x.view(batch_size, self.groups, self.channels // self.groups, -1) + x = x.view(batch_size, self.groups, -1) # Calculate the mean across first and last dimension; # i.e. the means for each feature $\mathbb{E}[x^{(k)}]$ - mean = x.mean(dim=[2, 3], keepdims=True) + mean = x.mean(dim=[2], keepdims=True) # Calculate the squared mean across first and last dimension; # i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$ - mean_x2 = (x ** 2).mean(dim=[2, 3], keepdims=True) + mean_x2 = (x ** 2).mean(dim=[2], keepdims=True) # Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$ var = mean_x2 - mean ** 2 diff --git a/labml_nn/normalization/group_norm/experiment.py b/labml_nn/normalization/group_norm/experiment.py index 6646db5f..8091bccb 100644 --- a/labml_nn/normalization/group_norm/experiment.py +++ b/labml_nn/normalization/group_norm/experiment.py @@ -40,19 +40,23 @@ class Model(Module): return self.fc(x) -@option(CIFAR10Configs.model) -def model(c: CIFAR10Configs): +class Configs(CIFAR10Configs): + groups: int = 16 + + +@option(Configs.model) +def model(c: Configs): """ ### Create model """ - return Model().to(c.device) + return Model(c.groups).to(c.device) def main(): # Create experiment experiment.create(name='cifar10', comment='group norm') # Create configurations - conf = CIFAR10Configs() + conf = Configs() # Load configurations experiment.configs(conf, { 'optimizer.optimizer': 'Adam',