Zero3 memory optimizations (#140)

This commit is contained in:
Varuna Jayasiri
2022-08-11 15:44:13 +05:30
committed by GitHub
parent 0bfb210671
commit 980a84ed4f
43 changed files with 12573 additions and 9 deletions

View File

@ -76,7 +76,7 @@
<a href='#section-0'>#</a>
</div>
<h1>Train a <a href="index.html">Vision Transformer (ViT)</a> on CIFAR 10</h1>
<p><a href="https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
<p><a href="https://app.labml.ai/run/afdd5332188b11edbdf543360515b595"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
</div>
<div class='code'>
@ -305,7 +305,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">1000</span><span class="p">,</span>
<div class="highlight"><pre><span class="lineno">79</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span>
<span class="lineno">80</span> <span class="s1">&#39;train_batch_size&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span></pre></div>
</div>
</div>

View File

@ -83,7 +83,7 @@
<p>When feeding the transformer with the patches, learned positional embeddings are added to the patch embeddings, because the patch embeddings do not have any information about where that patch is from. The positional embeddings are a set of vectors for each patch location that get trained with gradient descent along with other parameters.</p>
<p>ViTs perform well when they are pre-trained on large datasets. The paper suggests pre-training them with an MLP classification head and then using a single linear layer when fine-tuning. The paper beats SOTA with a ViT pre-trained on a 300 million image dataset. They also use higher resolution images during inference while keeping the patch size the same. The positional embeddings for new patch locations are calculated by interpolating learning positional embeddings.</p>
<p>Here&#x27;s <a href="experiment.html">an experiment</a> that trains ViT on CIFAR-10. This doesn&#x27;t do very well because it&#x27;s trained on a small dataset. It&#x27;s a simple experiment that anyone can run and play with ViTs.</p>
<p><a href="https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
<p><a href="https://app.labml.ai/run/afdd5332188b11edbdf543360515b595"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
</div>
<div class='code'>
@ -289,7 +289,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">117</span> <span class="n">pe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span></pre></div>
<div class="highlight"><pre><span class="lineno">117</span> <span class="n">pe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>