group norm cleanup

This commit is contained in:
Varuna Jayasiri
2021-04-20 17:05:16 +05:30
parent 834668437a
commit e47e0a5c82
2 changed files with 11 additions and 7 deletions

View File

@ -56,14 +56,14 @@ class GroupNorm(Module):
assert self.channels == x.shape[1] assert self.channels == x.shape[1]
# Reshape into `[batch_size, channels, n]` # Reshape into `[batch_size, channels, n]`
x = x.view(batch_size, self.groups, self.channels // self.groups, -1) x = x.view(batch_size, self.groups, -1)
# Calculate the mean across first and last dimension; # Calculate the mean across first and last dimension;
# i.e. the means for each feature $\mathbb{E}[x^{(k)}]$ # i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
mean = x.mean(dim=[2, 3], keepdims=True) mean = x.mean(dim=[2], keepdims=True)
# Calculate the squared mean across first and last dimension; # Calculate the squared mean across first and last dimension;
# i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$ # i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
mean_x2 = (x ** 2).mean(dim=[2, 3], keepdims=True) mean_x2 = (x ** 2).mean(dim=[2], keepdims=True)
# Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$ # Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
var = mean_x2 - mean ** 2 var = mean_x2 - mean ** 2

View File

@ -40,19 +40,23 @@ class Model(Module):
return self.fc(x) return self.fc(x)
@option(CIFAR10Configs.model) class Configs(CIFAR10Configs):
def model(c: CIFAR10Configs): groups: int = 16
@option(Configs.model)
def model(c: Configs):
""" """
### Create model ### Create model
""" """
return Model().to(c.device) return Model(c.groups).to(c.device)
def main(): def main():
# Create experiment # Create experiment
experiment.create(name='cifar10', comment='group norm') experiment.create(name='cifar10', comment='group norm')
# Create configurations # Create configurations
conf = CIFAR10Configs() conf = Configs()
# Load configurations # Load configurations
experiment.configs(conf, { experiment.configs(conf, {
'optimizer.optimizer': 'Adam', 'optimizer.optimizer': 'Adam',