mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-31 02:39:16 +08:00
Batch-channel norm MathJax fix
This commit is contained in:
@@ -136,12 +136,13 @@ class EstimatedBatchNorm(Module):
|
||||
# No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$
|
||||
with torch.no_grad():
|
||||
# Calculate the mean across first and last dimensions;
|
||||
# $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$
|
||||
# $$\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$$
|
||||
mean = x.mean(dim=[0, 2])
|
||||
# Calculate the squared mean across first and last dimensions;
|
||||
# $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$
|
||||
# $$\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$$
|
||||
mean_x2 = (x ** 2).mean(dim=[0, 2])
|
||||
# Variance for each feature \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
|
||||
# Variance for each feature
|
||||
# $$\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2$$
|
||||
var = mean_x2 - mean ** 2
|
||||
|
||||
# Update exponential moving averages
|
||||
|
||||
Reference in New Issue
Block a user