mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-15 10:11:39 +08:00
keepdims -> keepdim
This commit is contained in:
@ -60,10 +60,10 @@ class GroupNorm(Module):
|
||||
|
||||
# Calculate the mean across first and last dimension;
|
||||
# i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
|
||||
mean = x.mean(dim=[2], keepdims=True)
|
||||
mean = x.mean(dim=[2], keepdim=True)
|
||||
# Calculate the squared mean across first and last dimension;
|
||||
# i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
|
||||
mean_x2 = (x ** 2).mean(dim=[2], keepdims=True)
|
||||
mean_x2 = (x ** 2).mean(dim=[2], keepdim=True)
|
||||
# Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
|
||||
var = mean_x2 - mean ** 2
|
||||
|
||||
|
@ -57,10 +57,10 @@ class InstanceNorm(Module):
|
||||
|
||||
# Calculate the mean across first and last dimension;
|
||||
# i.e. the means for each feature $\mathbb{E}[x^{(k)}]$
|
||||
mean = x.mean(dim=[2], keepdims=True)
|
||||
mean = x.mean(dim=[2], keepdim=True)
|
||||
# Calculate the squared mean across first and last dimension;
|
||||
# i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$
|
||||
mean_x2 = (x ** 2).mean(dim=[2], keepdims=True)
|
||||
mean_x2 = (x ** 2).mean(dim=[2], keepdim=True)
|
||||
# Variance for each feature $Var[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
|
||||
var = mean_x2 - mean ** 2
|
||||
|
||||
|
@ -106,10 +106,10 @@ class LayerNorm(Module):
|
||||
|
||||
# Calculate the mean of all elements;
|
||||
# i.e. the means for each element $\mathbb{E}[X]$
|
||||
mean = x.mean(dim=dims, keepdims=True)
|
||||
mean = x.mean(dim=dims, keepdim=True)
|
||||
# Calculate the squared mean of all elements;
|
||||
# i.e. the means for each element $\mathbb{E}[X^2]$
|
||||
mean_x2 = (x ** 2).mean(dim=dims, keepdims=True)
|
||||
mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)
|
||||
# Variance of all element $Var[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$
|
||||
var = mean_x2 - mean ** 2
|
||||
|
||||
|
Reference in New Issue
Block a user