mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-06 23:57:19 +08:00
group norm cleanup
This commit is contained in:
@ -56,14 +56,14 @@ class GroupNorm(Module):
|
|||||||
assert self.channels == x.shape[1]
|
assert self.channels == x.shape[1]
|
||||||
|
|
||||||
# Reshape into `[batch_size, channels, n]`
|
# Reshape into `[batch_size, channels, n]`
|
||||||
x = x.view(batch_size, self.groups, self.channels // self.groups, -1)
|
x = x.view(batch_size, self.groups, -1)
|
||||||
|
|
||||||
# Calculate the mean across first and last dimension;
|
# Calculate the mean across first and last dimension;
|
||||||
# i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
|
# i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
|
||||||
mean = x.mean(dim=[2, 3], keepdims=True)
|
mean = x.mean(dim=[2], keepdims=True)
|
||||||
# Calculate the squared mean across first and last dimension;
|
# Calculate the squared mean across first and last dimension;
|
||||||
# i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
|
# i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
|
||||||
mean_x2 = (x ** 2).mean(dim=[2, 3], keepdims=True)
|
mean_x2 = (x ** 2).mean(dim=[2], keepdims=True)
|
||||||
# Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
|
# Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
|
||||||
var = mean_x2 - mean ** 2
|
var = mean_x2 - mean ** 2
|
||||||
|
|
||||||
|
|||||||
@ -40,19 +40,23 @@ class Model(Module):
|
|||||||
return self.fc(x)
|
return self.fc(x)
|
||||||
|
|
||||||
|
|
||||||
@option(CIFAR10Configs.model)
|
class Configs(CIFAR10Configs):
|
||||||
def model(c: CIFAR10Configs):
|
groups: int = 16
|
||||||
|
|
||||||
|
|
||||||
|
@option(Configs.model)
|
||||||
|
def model(c: Configs):
|
||||||
"""
|
"""
|
||||||
### Create model
|
### Create model
|
||||||
"""
|
"""
|
||||||
return Model().to(c.device)
|
return Model(c.groups).to(c.device)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Create experiment
|
# Create experiment
|
||||||
experiment.create(name='cifar10', comment='group norm')
|
experiment.create(name='cifar10', comment='group norm')
|
||||||
# Create configurations
|
# Create configurations
|
||||||
conf = CIFAR10Configs()
|
conf = Configs()
|
||||||
# Load configurations
|
# Load configurations
|
||||||
experiment.configs(conf, {
|
experiment.configs(conf, {
|
||||||
'optimizer.optimizer': 'Adam',
|
'optimizer.optimizer': 'Adam',
|
||||||
|
|||||||
Reference in New Issue
Block a user