<!DOCTYPE html>
<html>
<head>
    <meta http-equiv="content-type" content="text/html;charset=utf-8"/>
    <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
    <meta name="description" content=""/>

    <meta name="twitter:card" content="summary"/>
    <meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
    <meta name="twitter:title" content=" Vision Transformer (ViT)"/>
    <meta name="twitter:description" content=""/>
    <meta name="twitter:site" content="@labmlai"/>
    <meta name="twitter:creator" content="@labmlai"/>

    <meta property="og:url" content="https://nn.labml.ai/transformers/vit/readme.html"/>
    <meta property="og:title" content=" Vision Transformer (ViT)"/>
    <meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
    <meta property="og:site_name" content="LabML Neural Networks"/>
    <meta property="og:type" content="object"/>
    <meta property="og:description" content=""/>

    <title> Vision Transformer (ViT)</title>
    <link rel="shortcut icon" href="/icon.png"/>
    <link rel="stylesheet" href="../../pylit.css">
    <link rel="canonical" href="https://nn.labml.ai/transformers/vit/readme.html"/>
    <!-- Global site tag (gtag.js) - Google Analytics -->
    <script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
    <script>
        window.dataLayer = window.dataLayer || [];

        function gtag() {
            dataLayer.push(arguments);
        }

        gtag('js', new Date());

        gtag('config', 'G-4V3HC8HBLH');
    </script>
</head>
<body>
<div id='container'>
    <div id="background"></div>
    <div class='section'>
        <div class='docs'>
            <p>
                <a class="parent" href="/">home</a>
                <a class="parent" href="../index.html">transformers</a>
                <a class="parent" href="index.html">vit</a>
            </p>
            <p>

                <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/vit/readme.md">
                    <img alt="Github"
                         src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
                         style="max-width:100%;"/></a>
                <a href="https://twitter.com/labmlai"
                   rel="nofollow">
                    <img alt="Twitter"
                         src="https://img.shields.io/twitter/follow/labmlai?style=social"
                         style="max-width:100%;"/></a>
            </p>
        </div>
    </div>
    <div class='section' id='section-0'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-0'>#</a>
            </div>
            <h1><a href="https://nn.labml.ai/transformer/vit/index.html">Vision Transformer (ViT)</a></h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
<a href="https://papers.labml.ai/paper/2010.11929">An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale</a>.</p>
<p>The vision transformer applies a pure transformer to images,
without any convolution layers.
It splits the image into patches and applies a transformer to the patch embeddings.
<a href="https://nn.labml.ai/transformer/vit/index.html#PathEmbeddings">Patch embeddings</a> are generated by applying a simple linear transformation
to the flattened pixel values of each patch.
The patch embeddings, along with a classification token <code>[CLS]</code>, are then fed into a standard transformer encoder.
The encoder output at the <code>[CLS]</code> token is used to classify the image with an MLP.</p>
<p>Learned positional embeddings are added to the patch embeddings before they are fed to the transformer,
because the patch embeddings alone carry no information about where each patch came from.
The positional embeddings are a set of vectors, one for each patch location, that are trained
with gradient descent along with the other parameters.</p>
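<p>Learned positional embeddings can be sketched as a single trainable tensor that is broadcast-added to the token sequence. The class name, the shape convention <code>[1, n_positions, d_model]</code> (with one extra position for <code>[CLS]</code>), and the truncated-normal initialization with <code>std=0.02</code> are assumptions made for illustration.</p>

```python
import torch
import torch.nn as nn


class LearnedPositionalEmbeddings(nn.Module):
    """Hypothetical sketch: one trainable vector per token position."""

    def __init__(self, n_positions, d_model):
        super().__init__()
        # Parameter of shape [1, n_positions, d_model]; the optimizer updates it
        # together with every other weight in the model
        self.pos_emb = nn.Parameter(torch.empty(1, n_positions, d_model))
        # Small random initialization (an assumption, not taken from the text)
        nn.init.trunc_normal_(self.pos_emb, std=0.02)

    def forward(self, x):
        # x: [batch, n_positions, d_model]; the same positional vectors are
        # broadcast-added to every image in the batch
        return x + self.pos_emb
```

<p>With an 8×8 patch grid plus the <code>[CLS]</code> token, <code>n_positions</code> would be 65. Because the tensor is an <code>nn.Parameter</code>, its gradient is accumulated over the whole batch like any other weight.</p>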
<p>ViTs perform well when they are pre-trained on large datasets.
The paper suggests pre-training them with an MLP classification head and
then using a single linear layer when fine-tuning.
The paper beats the previous state of the art with a ViT pre-trained on a dataset of 300 million images.
It also uses higher-resolution images during fine-tuning and inference while keeping the
patch size the same, which yields more patches and therefore new patch locations.
The positional embeddings for the new patch locations are calculated by interpolating
the learned positional embeddings.</p>
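<p>The interpolation step could look roughly like the sketch below, assuming the positional embeddings are stored as <code>[1, 1 + h·w, d_model]</code> with the <code>[CLS]</code> embedding first and a square pre-training grid. The function name <code>resize_pos_emb</code> and the choice of bilinear interpolation via <code>torch.nn.functional.interpolate</code> are assumptions; the paper describes a 2D interpolation of the pre-trained position embeddings according to their location in the original image.</p>

```python
import math

import torch
import torch.nn.functional as F


def resize_pos_emb(pos_emb, new_grid):
    """Hypothetical helper: resize learned positional embeddings to a new patch grid.

    pos_emb:  [1, 1 + old_h * old_w, d_model], [CLS] embedding first.
    new_grid: (new_h, new_w), the number of patches along each axis.
    """
    cls_emb, patch_emb = pos_emb[:, :1], pos_emb[:, 1:]
    d_model = patch_emb.shape[-1]
    old = math.isqrt(patch_emb.shape[1])  # assumes the old grid is square
    # Lay the patch embeddings out as a d_model-channel 2D image: [1, d, old, old]
    grid = patch_emb.reshape(1, old, old, d_model).permute(0, 3, 1, 2)
    # 2D interpolation over patch locations (bilinear is an assumed choice)
    grid = F.interpolate(grid, size=new_grid, mode="bilinear", align_corners=False)
    # Flatten back to [1, new_h * new_w, d] and re-attach the unchanged [CLS] embedding
    new_h, new_w = new_grid
    patch_emb = grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, d_model)
    return torch.cat([cls_emb, patch_emb], dim=1)
```

<p>For example, with 16×16 patches a model pre-trained on 224×224 images has a 14×14 patch grid, and 384×384 images give a 24×24 grid, so <code>resize_pos_emb(pos_emb, (24, 24))</code> turns 1 + 196 = 197 positions into 1 + 576 = 577. The <code>[CLS]</code> embedding has no spatial location, so it is copied unchanged.</p>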
<p>Here&rsquo;s <a href="https://nn.labml.ai/transformer/vit/experiment.html">an experiment</a> that trains ViT on CIFAR-10.
This doesn&rsquo;t do very well because it&rsquo;s trained on a small dataset.
It&rsquo;s a simple experiment that anyone can run and play with ViTs.</p>
        </div>
        <div class='code'>
            
        </div>
    </div>
    <div class='footer'>
        <a href="https://papers.labml.ai">Trending Research Papers</a>
        <a href="https://labml.ai">labml.ai</a>
    </div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
    MathJax.Hub.Config({
        tex2jax: {
            inlineMath: [ ['$','$'] ],
            displayMath: [ ['$$','$$'] ],
            processEscapes: true,
            processEnvironments: true
        },
        // Center justify equations in code and markdown cells. Elsewhere
        // we use CSS to left justify single line equations in code cells.
        displayAlign: 'center',
        "HTML-CSS": { fonts: ["TeX"] }
    });

</script>
<script>
    function handleImages() {
        var images = document.querySelectorAll('p>img')

        for (var i = 0; i < images.length; ++i) {
            handleImage(images[i])
        }
    }

    function handleImage(img) {
        img.parentElement.style.textAlign = 'center'

        var modal = document.createElement('div')
        modal.id = 'modal'

        var modalContent = document.createElement('div')
        modal.appendChild(modalContent)

        var modalImage = document.createElement('img')
        modalContent.appendChild(modalImage)

        var span = document.createElement('span')
        span.classList.add('close')
        span.textContent = 'x'
        modal.appendChild(span)

        img.onclick = function () {
            document.body.appendChild(modal)
            modalImage.src = img.src
        }

        span.onclick = function () {
            document.body.removeChild(modal)
        }
    }

    handleImages()
</script>
</body>
</html>