mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 18:58:43 +08:00 
			
		
		
		
	batch channel norm mathjax fix
This commit is contained in:
		| @ -344,7 +344,8 @@ $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.</p> | |||||||
|                     <a href='#section-19'>#</a> |                     <a href='#section-19'>#</a> | ||||||
|                 </div> |                 </div> | ||||||
|                 <p>Calculate the mean across first and last dimensions; |                 <p>Calculate the mean across first and last dimensions; | ||||||
| $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$</p> | <script type="math/tex; mode=display">\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}</script> | ||||||
|  | </p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">140</span>                <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">140</span>                <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div> | ||||||
| @ -356,7 +357,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$</p> | |||||||
|                     <a href='#section-20'>#</a> |                     <a href='#section-20'>#</a> | ||||||
|                 </div> |                 </div> | ||||||
|                 <p>Calculate the mean of the squares across first and last dimensions; |                 <p>Calculate the mean of the squares across first and last dimensions; | ||||||
| $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | <script type="math/tex; mode=display">\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}</script> | ||||||
|  | </p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">143</span>                <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">143</span>                <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span></pre></div> | ||||||
| @ -367,10 +369,12 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | |||||||
|                 <div class='section-link'> |                 <div class='section-link'> | ||||||
|                     <a href='#section-21'>#</a> |                     <a href='#section-21'>#</a> | ||||||
|                 </div> |                 </div> | ||||||
|                 <p>Variance for each feature \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2</p> |                 <p>Variance for each feature | ||||||
|  | <script type="math/tex; mode=display">\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2</script> | ||||||
|  | </p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">145</span>                <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">146</span>                <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-22'> |     <div class='section' id='section-22'> | ||||||
| @ -386,8 +390,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | |||||||
| </p> | </p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">152</span>                <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span> |                 <div class="highlight"><pre><span class="lineno">153</span>                <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span> | ||||||
| <span class="lineno">153</span>                <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span></pre></div> | <span class="lineno">154</span>                <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-23'> |     <div class='section' id='section-23'> | ||||||
| @ -400,7 +404,7 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | |||||||
| </p> | </p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">157</span>        <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">158</span>        <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">exp_mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">exp_var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-24'> |     <div class='section' id='section-24'> | ||||||
| @ -415,8 +419,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | |||||||
| </p> | </p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">162</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span> |                 <div class="highlight"><pre><span class="lineno">163</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span> | ||||||
| <span class="lineno">163</span>            <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div> | <span class="lineno">164</span>            <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-25'> |     <div class='section' id='section-25'> | ||||||
| @ -427,7 +431,7 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | |||||||
|                 <p>Reshape to original and return</p> |                 <p>Reshape to original and return</p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">166</span>        <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">167</span>        <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-26'> |     <div class='section' id='section-26'> | ||||||
| @ -439,7 +443,7 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | |||||||
| <p>This is similar to <a href="../group_norm/index.html">Group Normalization</a> but affine transform is done group-wise.</p> | <p>This is similar to <a href="../group_norm/index.html">Group Normalization</a> but affine transform is done group-wise.</p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">169</span><span class="k">class</span> <span class="nc">ChannelNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">170</span><span class="k">class</span> <span class="nc">ChannelNorm</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-27'> |     <div class='section' id='section-27'> | ||||||
| @ -455,8 +459,8 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | |||||||
| </ul> | </ul> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">176</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> |                 <div class="highlight"><pre><span class="lineno">177</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> | ||||||
| <span class="lineno">177</span>                 <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div> | <span class="lineno">178</span>                 <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-28'> |     <div class='section' id='section-28'> | ||||||
| @ -467,11 +471,11 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | |||||||
|                  |                  | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">184</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> |                 <div class="highlight"><pre><span class="lineno">185</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> | ||||||
| <span class="lineno">185</span>        <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span> | <span class="lineno">186</span>        <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">=</span> <span class="n">channels</span> | ||||||
| <span class="lineno">186</span>        <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span> | <span class="lineno">187</span>        <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span> | ||||||
| <span class="lineno">187</span>        <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span> | <span class="lineno">188</span>        <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span> | ||||||
| <span class="lineno">188</span>        <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div> | <span class="lineno">189</span>        <span class="bp">self</span><span class="o">.</span><span class="n">affine</span> <span class="o">=</span> <span class="n">affine</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-29'> |     <div class='section' id='section-29'> | ||||||
| @ -484,9 +488,9 @@ $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$</p> | |||||||
| they are transformed channel-wise.</em></p> | they are transformed channel-wise.</em></p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">193</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span> |                 <div class="highlight"><pre><span class="lineno">194</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span> | ||||||
| <span class="lineno">194</span>            <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span> | <span class="lineno">195</span>            <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span> | ||||||
| <span class="lineno">195</span>            <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span></pre></div> | <span class="lineno">196</span>            <span class="bp">self</span><span class="o">.</span><span class="n">shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">groups</span><span class="p">))</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-30'> |     <div class='section' id='section-30'> | ||||||
| @ -500,7 +504,7 @@ they are transformed channel-wise.</em></p> | |||||||
| <code>[batch_size, channels, height, width]</code></p> | <code>[batch_size, channels, height, width]</code></p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">197</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">198</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-31'> |     <div class='section' id='section-31'> | ||||||
| @ -511,7 +515,7 @@ they are transformed channel-wise.</em></p> | |||||||
|                 <p>Keep the original shape</p> |                 <p>Keep the original shape</p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">206</span>        <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">207</span>        <span class="n">x_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-32'> |     <div class='section' id='section-32'> | ||||||
| @ -522,7 +526,7 @@ they are transformed channel-wise.</em></p> | |||||||
|                 <p>Get the batch size</p> |                 <p>Get the batch size</p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">208</span>        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">209</span>        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-33'> |     <div class='section' id='section-33'> | ||||||
| @ -533,7 +537,7 @@ they are transformed channel-wise.</em></p> | |||||||
|                 <p>Sanity check to make sure the number of features is the same</p> |                 <p>Sanity check to make sure the number of features is the same</p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">210</span>        <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">211</span>        <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-34'> |     <div class='section' id='section-34'> | ||||||
| @ -544,7 +548,7 @@ they are transformed channel-wise.</em></p> | |||||||
|                 <p>Reshape into <code>[batch_size, groups, n]</code></p> |                 <p>Reshape into <code>[batch_size, groups, n]</code></p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">213</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">214</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-35'> |     <div class='section' id='section-35'> | ||||||
| @ -556,7 +560,7 @@ they are transformed channel-wise.</em></p> | |||||||
| i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p> | i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">217</span>        <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">218</span>        <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-36'> |     <div class='section' id='section-36'> | ||||||
| @ -568,7 +572,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$</p | |||||||
| i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$</p> | i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$</p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">220</span>        <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">221</span>        <span class="n">mean_x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-37'> |     <div class='section' id='section-37'> | ||||||
| @ -580,7 +584,7 @@ i.e. the means for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$< | |||||||
| $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$</p> | $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$</p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">223</span>        <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">224</span>        <span class="n">var</span> <span class="o">=</span> <span class="n">mean_x2</span> <span class="o">-</span> <span class="n">mean</span> <span class="o">**</span> <span class="mi">2</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-38'> |     <div class='section' id='section-38'> | ||||||
| @ -594,7 +598,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}] | |||||||
| </p> | </p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">228</span>        <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">229</span>        <span class="n">x_norm</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-39'> |     <div class='section' id='section-39'> | ||||||
| @ -607,8 +611,8 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}] | |||||||
| </p> | </p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">232</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span> |                 <div class="highlight"><pre><span class="lineno">233</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">affine</span><span class="p">:</span> | ||||||
| <span class="lineno">233</span>            <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div> | <span class="lineno">234</span>            <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     <div class='section' id='section-40'> |     <div class='section' id='section-40'> | ||||||
| @ -619,7 +623,7 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}] | |||||||
|                 <p>Reshape to original and return</p> |                 <p>Reshape to original and return</p> | ||||||
|             </div> |             </div> | ||||||
|             <div class='code'> |             <div class='code'> | ||||||
|                 <div class="highlight"><pre><span class="lineno">236</span>        <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div> |                 <div class="highlight"><pre><span class="lineno">237</span>        <span class="k">return</span> <span class="n">x_norm</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x_shape</span><span class="p">)</span></pre></div> | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
|  | |||||||
| @ -136,12 +136,13 @@ class EstimatedBatchNorm(Module): | |||||||
|             # No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ |             # No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ | ||||||
|             with torch.no_grad(): |             with torch.no_grad(): | ||||||
|                 # Calculate the mean across first and last dimensions; |                 # Calculate the mean across first and last dimensions; | ||||||
|                 # $\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$ |                 # $$\frac{1}{B H W} \sum_{b,h,w} X_{b,c,h,w}$$ | ||||||
|                 mean = x.mean(dim=[0, 2]) |                 mean = x.mean(dim=[0, 2]) | ||||||
|                 # Calculate the mean of the squares across first and last dimensions; |                 # Calculate the mean of the squares across first and last dimensions; | ||||||
|                 # $\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$ |                 # $$\frac{1}{B H W} \sum_{b,h,w} X^2_{b,c,h,w}$$ | ||||||
|                 mean_x2 = (x ** 2).mean(dim=[0, 2]) |                 mean_x2 = (x ** 2).mean(dim=[0, 2]) | ||||||
|                 # Variance for each feature \frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2 |                 # Variance for each feature | ||||||
|  |                 # $$\frac{1}{B H W} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C \big)^2$$ | ||||||
|                 var = mean_x2 - mean ** 2 |                 var = mean_x2 - mean ** 2 | ||||||
|  |  | ||||||
|                 # Update exponential moving averages |                 # Update exponential moving averages | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri