mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-02 13:00:17 +08:00
group norm cleanup
This commit is contained in:
@ -56,14 +56,14 @@ class GroupNorm(Module):
|
||||
assert self.channels == x.shape[1]
|
||||
|
||||
# Reshape into `[batch_size, channels, n]`
|
||||
x = x.view(batch_size, self.groups, self.channels // self.groups, -1)
|
||||
x = x.view(batch_size, self.groups, -1)
|
||||
|
||||
# Calculate the mean across first and last dimension;
|
||||
# i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
|
||||
mean = x.mean(dim=[2, 3], keepdims=True)
|
||||
mean = x.mean(dim=[2], keepdims=True)
|
||||
# Calculate the squared mean across first and last dimension;
|
||||
# i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
|
||||
mean_x2 = (x ** 2).mean(dim=[2, 3], keepdims=True)
|
||||
mean_x2 = (x ** 2).mean(dim=[2], keepdims=True)
|
||||
# Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
|
||||
var = mean_x2 - mean ** 2
|
||||
|
||||
|
||||
@ -40,19 +40,23 @@ class Model(Module):
|
||||
return self.fc(x)
|
||||
|
||||
|
||||
@option(CIFAR10Configs.model)
|
||||
def model(c: CIFAR10Configs):
|
||||
class Configs(CIFAR10Configs):
|
||||
groups: int = 16
|
||||
|
||||
|
||||
@option(Configs.model)
|
||||
def model(c: Configs):
|
||||
"""
|
||||
### Create model
|
||||
"""
|
||||
return Model().to(c.device)
|
||||
return Model(c.groups).to(c.device)
|
||||
|
||||
|
||||
def main():
|
||||
# Create experiment
|
||||
experiment.create(name='cifar10', comment='group norm')
|
||||
# Create configurations
|
||||
conf = CIFAR10Configs()
|
||||
conf = Configs()
|
||||
# Load configurations
|
||||
experiment.configs(conf, {
|
||||
'optimizer.optimizer': 'Adam',
|
||||
|
||||
Reference in New Issue
Block a user