group norm cleanup

This commit is contained in:
Varuna Jayasiri
2021-04-20 17:05:16 +05:30
parent 834668437a
commit e47e0a5c82
2 changed files with 11 additions and 7 deletions

View File

@ -56,14 +56,14 @@ class GroupNorm(Module):
assert self.channels == x.shape[1]
# Reshape into `[batch_size, channels, n]`
x = x.view(batch_size, self.groups, self.channels // self.groups, -1)
x = x.view(batch_size, self.groups, -1)
# Calculate the mean across first and last dimension;
# i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
mean = x.mean(dim=[2, 3], keepdims=True)
mean = x.mean(dim=[2], keepdims=True)
# Calculate the squared mean across first and last dimension;
# i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
mean_x2 = (x ** 2).mean(dim=[2, 3], keepdims=True)
mean_x2 = (x ** 2).mean(dim=[2], keepdims=True)
# Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
var = mean_x2 - mean ** 2

View File

@ -40,19 +40,23 @@ class Model(Module):
return self.fc(x)
@option(CIFAR10Configs.model)
def model(c: CIFAR10Configs):
class Configs(CIFAR10Configs):
groups: int = 16
@option(Configs.model)
def model(c: Configs):
"""
### Create model
"""
return Model().to(c.device)
return Model(c.groups).to(c.device)
def main():
# Create experiment
experiment.create(name='cifar10', comment='group norm')
# Create configurations
conf = CIFAR10Configs()
conf = Configs()
# Load configurations
experiment.configs(conf, {
'optimizer.optimizer': 'Adam',