Fix MathJax rendering of batch–channel norm comments (use display-math delimiters)

This commit is contained in:
Varuna Jayasiri
2021-07-04 13:32:56 +05:30
parent f00d1d61f7
commit af6b99a5f0
2 changed files with 41 additions and 36 deletions

View File

@ -136,12 +136,13 @@ class EstimatedBatchNorm(Module):
# No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$
with torch.no_grad():
# Calculate the mean across first and last dimensions;
# $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$
# $$\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$$
mean = x.mean(dim=[0, 2])
# Calculate the squared mean across first and last dimensions;
# $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$
# $$\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$$
mean_x2 = (x ** 2).mean(dim=[0, 2])
# Variance for each feature \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2
# Variance for each feature
# $$\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2$$
var = mean_x2 - mean ** 2
# Update exponential moving averages