<!-- Last commit: Varuna Jayasiri, 62c5786d31 "html", 2022-02-23 15:15:59 +05:30 -->
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="Implementation/tutorial of GPT model and training code."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="GPT"/>
<meta name="twitter:description" content="Implementation/tutorial of GPT model and training code."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/gpt/index.html"/>
<meta property="og:title" content="GPT"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="GPT"/>
<meta property="og:description" content="Implementation/tutorial of GPT model and training code."/>
<title>GPT</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/transformers/gpt/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">transformers</a>
<a class="parent" href="index.html">gpt</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/gpt/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>GPT</h1>
<p>This is a tutorial/implementation of the <a href="https://openai.com/blog/better-language-models/">OpenAI GPT architecture</a> in <a href="https://pytorch.org">PyTorch</a>. We took a number of implementation details from <a href="https://github.com/karpathy/minGPT">minGPT</a> by <a href="https://twitter.com/karpathy">@karpathy</a>. This implementation also uses the character-level Tiny Shakespeare dataset.</p>
<p>The GPT model is essentially a standard transformer with a few tweaks. GPT-2 and especially GPT-3 models are quite large: they won&#x27;t fit on a single GPU and need model parallelism. This implementation doesn&#x27;t even use data parallelism; it is intended as a tutorial.</p>
<p>The main differences from a simple autoregressive transformer are the parameter initialization, the weight decay, and the learning rate schedule. For the transformer itself we reuse the <a href="../transformers/index.html">existing labml/nn transformer implementation</a>.</p>
<p>Here&#x27;s a notebook for training a GPT model on the Tiny Shakespeare dataset.</p>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/gpt/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://app.labml.ai/run/0324c6d0562111eba65d0242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">35</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">36</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">37</span>
<span class="lineno">38</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">39</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">40</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">41</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.nlp_autoregression</span> <span class="kn">import</span> <span class="n">NLPAutoRegressionConfigs</span>
<span class="lineno">42</span><span class="kn">from</span> <span class="nn">labml_nn.optimizers.configs</span> <span class="kn">import</span> <span class="n">OptimizerConfigs</span>
<span class="lineno">43</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span><span class="p">,</span> <span class="n">Encoder</span>
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.utils</span> <span class="kn">import</span> <span class="n">subsequent_mask</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>GPT model</h2>
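<p>This consists of a token embedding layer, a transformer encoder, and a final linear layer that gives the token logits. The encoder is run with a subsequent (causal) mask, so each position attends only to itself and earlier positions.</p>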
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span><span class="k">class</span> <span class="nc">GPT</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">encoder</span></code>
is the transformer <a href="../models.html#Encoder">Encoder</a> </li>
<li><code class="highlight"><span></span><span class="n">src_embed</span></code>
is the token <a href="../models.html#EmbeddingsWithLearnedPositionalEncoding">embedding module (with positional encodings)</a> </li>
<li><code class="highlight"><span></span><span class="n">generator</span></code>
is the <a href="../models.html#Generator">final fully connected layer</a> that gives the logits.</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">Module</span><span class="p">,</span> <span class="n">generator</span><span class="p">:</span> <span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">63</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span> <span class="o">=</span> <span class="n">src_embed</span>
<span class="lineno">64</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span>
<span class="lineno">65</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">generator</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>The mask is created lazily, on the first call to <code>forward</code> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Create the subsequent mask if it has not been initialized yet, or if its size does not match the length of the input (the size of its first dimension) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>The subsequent mask prevents each token from attending to future tokens </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span> <span class="o">=</span> <span class="n">subsequent_mask</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Get the token embeddings with positional encodings </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">77</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Transformer encoder </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Get logits </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Return the results (the second value is for the state, since our trainer is also used with RNNs) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="kc">None</span></pre></div>
</div>
</div>
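<div class='section'>
<div class='docs'>
<p>A minimal sketch of what the subsequent (causal) mask does, written in plain PyTorch. It assumes, without quoting the labml source, that <code>subsequent_mask</code> returns a lower-triangular boolean tensor of shape <code>[seq_len, seq_len, 1]</code> in which <code>True</code> means "may attend", and that attention scores are laid out as <code>[query, key, batch]</code>. The helpers <code>causal_mask</code> and <code>masked_attention_weights</code> are hypothetical names used only for illustration.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch


def causal_mask(seq_len):
    # Hypothetical stand-in for subsequent_mask (assumed behaviour):
    # entry [i, j] is True when query position i may attend to key position j,
    # i.e. when j is not a future position. The trailing size-1 dimension
    # broadcasts over the batch.
    return torch.tril(torch.ones(seq_len, seq_len)).to(torch.bool).unsqueeze(-1)


def masked_attention_weights(scores, mask):
    # scores: raw attention logits shaped [query, key, batch] (assumed layout)
    # Future positions get -inf, so softmax assigns them exactly zero weight
    scores = scores.masked_fill(~mask, float('-inf'))
    # Normalize over the key dimension
    return torch.softmax(scores, dim=1)


seq_len, batch_size = 4, 2
mask = causal_mask(seq_len)
weights = masked_attention_weights(torch.randn(seq_len, seq_len, batch_size), mask)
print(mask[:, :, 0].int().tolist())
# [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
```
</pre></div>
</div>
</div>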
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<h2>Configurations</h2>
<p>This inherits from <a href="../../experiments/nlp_autoregression.html#NLPAutoRegressionConfigs"><code class="highlight"><span></span><span class="n">NLPAutoRegressionConfigs</span></code>
</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>GPT model </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span> <span class="n">model</span><span class="p">:</span> <span class="n">GPT</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Transformer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">99</span> <span class="n">transformer</span><span class="p">:</span> <span class="n">TransformerConfigs</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Weight decay </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="n">weight_decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Number of tokens for learning rate warmup </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span> <span class="n">warmup_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128</span> <span class="o">*</span> <span class="mi">128</span> <span class="o">*</span> <span class="mi">20</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Custom optimizer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="s1">&#39;transformer_optimizer&#39;</span></pre></div>
</div>
</div>
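<div class='section'>
<div class='docs'>
<p>A small worked example of the warmup arithmetic. The product <code>128 * 128 * 20</code> reads like 20 batches of 128 sequences of 128 tokens each; the batch size and sequence length below are assumed values, and dividing by the tokens per batch to get optimizer steps is an assumption about how the trainer consumes this setting.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
# Warmup length as configured above, measured in tokens
warmup_steps = 128 * 128 * 20

# Assumed (hypothetical) training shape: not read from this file
batch_size = 128
seq_len = 128

# Each optimizer step consumes batch_size * seq_len tokens
tokens_per_step = batch_size * seq_len
warmup_optimizer_steps = warmup_steps // tokens_per_step

print(warmup_steps, tokens_per_step, warmup_optimizer_steps)  # 327680 16384 20
```
</pre></div>
</div>
</div>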
<div class='section' id='section-18'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<h3>Transformer configurations</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">109</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">transformer</span><span class="p">,</span> <span class="s1">&#39;GPT&#39;</span><span class="p">)</span>
<span class="lineno">110</span><span class="k">def</span> <span class="nf">_transformer_configs</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>We use our <a href="../configs.html#TransformerConfigs">configurable transformer implementation</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">117</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">TransformerConfigs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Set the vocabulary sizes for embeddings and generating logits </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="n">conf</span><span class="o">.</span><span class="n">n_src_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span>
<span class="lineno">120</span> <span class="n">conf</span><span class="o">.</span><span class="n">n_tgt_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>GPT uses the GELU activation in the position-wise feedforward network </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="n">conf</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="s1">&#39;GELU&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">125</span> <span class="k">return</span> <span class="n">conf</span></pre></div>
</div>
</div>
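<div class='section'>
<div class='docs'>
<p>A minimal sketch of a position-wise feedforward block with GELU, in plain PyTorch rather than labml&#x27;s configurable feedforward module. The class name <code>PositionWiseFFN</code> and the layer sizes are hypothetical. The check compares <code>nn.GELU</code> (whose default is the exact, erf-based form) against the definition GELU(x) = x Φ(x), where Φ is the standard normal CDF.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import math

import torch
from torch import nn


class PositionWiseFFN(nn.Module):
    # Hypothetical minimal FFN: Linear, then GELU, then Linear.
    # It acts independently at every position (the last dim is d_model).
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)
        self.activation = nn.GELU()
        self.layer2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.layer2(self.activation(self.layer1(x)))


def gelu_exact(x):
    # GELU(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))


x = torch.linspace(-3.0, 3.0, steps=7)
print(torch.allclose(nn.GELU()(x), gelu_exact(x), atol=1e-6))  # True

ffn = PositionWiseFFN(d_model=16, d_ff=64)
# Input [seq_len, batch_size, d_model]; output has the same shape
out = ffn(torch.randn(10, 2, 16))
print(out.shape)  # torch.Size([10, 2, 16])
```
</pre></div>
</div>
</div>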
<div class='section' id='section-23'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<h3>Initialize weights</h3>
<p>Weights of linear layers and embedding layers are initialized from <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span><span class="mord">.</span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span><span class="mord">2</span><span class="mclose">)</span></span></span></span> (mean 0, standard deviation 0.02) instead of PyTorch&#x27;s default initialization (Kaiming-uniform for <code>nn.Linear</code> and a standard normal for <code>nn.Embedding</code>).</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span><span class="k">def</span> <span class="nf">_init_weights</span><span class="p">(</span><span class="n">module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">)):</span>
<span class="lineno">138</span> <span class="k">return</span>
<span class="lineno">139</span>
<span class="lineno">140</span> <span class="n">module</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">0.02</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Initialize biases to <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">)</span> <span class="ow">and</span> <span class="n">module</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">144</span> <span class="n">module</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">zero_</span><span class="p">()</span></pre></div>
</div>
</div>
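<div class='section'>
<div class='docs'>
<p>A quick way to check what <code>_init_weights</code> does, on a hypothetical toy model rather than the GPT model. The function <code>init_gpt_style</code> below follows the same rule but uses <code>nn.init.normal_</code> and <code>nn.init.zeros_</code> instead of modifying <code>.data</code> in place. <code>nn.Module.apply</code> visits every submodule recursively, so modules the rule does not match, such as <code>nn.LayerNorm</code>, keep their defaults.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
from torch import nn


def init_gpt_style(module):
    # Same rule as _init_weights: only Linear and Embedding layers are touched
    if not isinstance(module, (nn.Linear, nn.Embedding)):
        return
    # Weights drawn from a normal distribution with mean 0 and std 0.02
    nn.init.normal_(module.weight, mean=0.0, std=0.02)
    # Linear biases (if present) start at zero
    if isinstance(module, nn.Linear) and module.bias is not None:
        nn.init.zeros_(module.bias)


torch.manual_seed(0)
# Hypothetical toy model, not the GPT model above
toy = nn.Sequential(nn.Embedding(1000, 256), nn.LayerNorm(256), nn.Linear(256, 256))
toy.apply(init_gpt_style)

print(round(toy[0].weight.std().item(), 3))  # about 0.02
print(toy[2].bias.abs().max().item())  # 0.0
print(torch.all(toy[1].weight == 1).item())  # True: LayerNorm left untouched
```
</pre></div>
</div>
</div>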
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p> Create GPT model and initialize weights</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">148</span><span class="k">def</span> <span class="nf">_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">152</span> <span class="n">m</span> <span class="o">=</span> <span class="n">GPT</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span>
<span class="lineno">153</span> <span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span>
<span class="lineno">154</span> <span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">generator</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Apply custom weight initialization </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span> <span class="n">m</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">_init_weights</span><span class="p">)</span>
<span class="lineno">158</span>
<span class="lineno">159</span> <span class="k">return</span> <span class="n">m</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<h3>Create custom optimizer with weight decay</h3>
<p>This code is taken from <a href="https://github.com/karpathy/minGPT">minGPT</a>. It applies weight decay only to the weights of linear layers; biases, embeddings, and all other parameters are not decayed.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">optimizer</span><span class="p">)</span>
<span class="lineno">163</span><span class="k">def</span> <span class="nf">transformer_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Collect the names of parameters that weight decay should be applied to </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="n">decay</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
<span class="lineno">172</span> <span class="k">for</span> <span class="n">mn</span><span class="p">,</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">c</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">named_modules</span><span class="p">():</span>
<span class="lineno">173</span> <span class="k">for</span> <span class="n">pn</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">m</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">():</span>
<span class="lineno">174</span> <span class="n">fpn</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">mn</span><span class="si">}</span><span class="s1">.</span><span class="si">{</span><span class="n">pn</span><span class="si">}</span><span class="s1">&#39;</span> <span class="k">if</span> <span class="n">mn</span> <span class="k">else</span> <span class="n">pn</span> <span class="c1"># full param name</span>
<span class="lineno">175</span>
<span class="lineno">176</span> <span class="k">if</span> <span class="n">fpn</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="s1">&#39;weight&#39;</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">):</span>
<span class="lineno">177</span> <span class="n">decay</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">fpn</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Get all the parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">180</span> <span class="n">param_dict</span> <span class="o">=</span> <span class="p">{</span><span class="n">pn</span><span class="p">:</span> <span class="n">p</span> <span class="k">for</span> <span class="n">pn</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">c</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">()}</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Parameters that are not decayed </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">182</span> <span class="n">no_decay</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">param_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="o">-</span> <span class="n">decay</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Create the parameter groups for the PyTorch optimizer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">185</span> <span class="n">opt_groups</span> <span class="o">=</span> <span class="p">[</span>
<span class="lineno">186</span> <span class="p">{</span><span class="s2">&quot;params&quot;</span><span class="p">:</span> <span class="p">[</span><span class="n">param_dict</span><span class="p">[</span><span class="n">pn</span><span class="p">]</span> <span class="k">for</span> <span class="n">pn</span> <span class="ow">in</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">decay</span><span class="p">))],</span> <span class="s2">&quot;weight_decay&quot;</span><span class="p">:</span> <span class="n">c</span><span class="o">.</span><span class="n">weight_decay</span><span class="p">},</span>
<span class="lineno">187</span> <span class="p">{</span><span class="s2">&quot;params&quot;</span><span class="p">:</span> <span class="p">[</span><span class="n">param_dict</span><span class="p">[</span><span class="n">pn</span><span class="p">]</span> <span class="k">for</span> <span class="n">pn</span> <span class="ow">in</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">no_decay</span><span class="p">))],</span> <span class="s2">&quot;weight_decay&quot;</span><span class="p">:</span> <span class="mf">0.0</span><span class="p">},</span>
<span class="lineno">188</span> <span class="p">]</span></pre></div>
</div>
</div>
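<div class='section'>
<div class='docs'>
<p>A self-contained sketch of the same parameter split on a hypothetical toy model, with <code>torch.optim.AdamW</code> standing in for labml&#x27;s configurable optimizer. The helper <code>split_decay_params</code> is hypothetical; unlike the code above, it passes <code>recurse=False</code> so each parameter is visited once, through the module that owns it, which yields the same two sets.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
from torch import nn


def split_decay_params(model):
    # Names of parameters that get weight decay: weights of nn.Linear layers only
    decay = set()
    for mn, m in model.named_modules():
        # recurse=False: only parameters owned directly by this module
        for pn, _ in m.named_parameters(recurse=False):
            fpn = f'{mn}.{pn}' if mn else pn  # full parameter name
            if pn == 'weight' and isinstance(m, nn.Linear):
                decay.add(fpn)
    param_dict = dict(model.named_parameters())
    # Everything else (biases, embeddings, LayerNorm) is not decayed
    no_decay = set(param_dict) - decay
    return param_dict, decay, no_decay


# Hypothetical toy model: embedding, layer norm, and output projection
toy = nn.Sequential(nn.Embedding(10, 8), nn.LayerNorm(8), nn.Linear(8, 10))
param_dict, decay, no_decay = split_decay_params(toy)

opt_groups = [
    {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": 0.1},
    {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(opt_groups, lr=6e-4)

print(sorted(decay))     # ['2.weight']
print(sorted(no_decay))  # ['0.weight', '1.bias', '1.weight', '2.bias']
```
</pre></div>
</div>
</div>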
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Create a <a href="../optimizers/configs.html#OptimizerConfigs">configurable optimizer</a>, so that we can change these simply by passing a config dictionary. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">OptimizerConfigs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Set parameter groups for optimization. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">parameters</span> <span class="o">=</span> <span class="n">opt_groups</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Use the <a href="../optimizers/adam_warmup_cosine_decay.html">Adam optimizer with warmup and cosine decay</a>. This is the learning rate schedule GPT uses. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="s1">&#39;AdamWarmupCosineDecay&#39;</span></pre></div>
</div>
</div>
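<div class='section'>
<div class='docs'>
<p>The exact formula inside <code>AdamWarmupCosineDecay</code> is not shown in this file. As a hedged sketch, a common form of this schedule ramps the learning rate linearly up to its maximum over <code>warmup</code> steps and then follows half a cosine down to zero at <code>total_steps</code>. <code>warmup_cosine_lr</code> is a hypothetical helper.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import math


def warmup_cosine_lr(step: int, max_lr: float, warmup: int, total_steps: int) -> float:
    """Assumed schedule: linear warmup to max_lr, then cosine decay to 0."""
    if warmup > step:
        # Linear ramp from 0, reaching max_lr at step == warmup
        return max_lr * step / warmup
    # Fraction of the post-warmup phase completed, clamped to at most 1
    progress = min(1.0, (step - warmup) / max(1, total_steps - warmup))
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Peak at the end of warmup, half the peak midway through the decay
for step in (0, 10, 20, 520, 1020):
    print(step, warmup_cosine_lr(step, max_lr=6e-4, warmup=20, total_steps=1020))
```
</pre></div>
</div>
</div>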
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
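<p>Set the model embedding size. This is only required by the <a href="../optimizers/noam.html">Noam optimizer</a>, which scales the learning rate by the inverse square root of <code>d_model</code> and, after warmup, decays it with the inverse square root of the step number. </p>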
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">202</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span></pre></div>
</div>
</div>
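<div class='section'>
<div class='docs'>
<p>For comparison, the Noam schedule from the original Transformer paper sets the learning rate to d<sub>model</sub><sup>−0.5</sup> · min(step<sup>−0.5</sup>, step · warmup<sup>−1.5</sup>). It grows linearly for <code>warmup</code> steps and then decays with the inverse square root of the step number, which is why it needs <code>d_model</code>. <code>noam_lr</code> is a hypothetical helper, and its extra <code>factor</code> multiplier is an assumed knob that may differ from how the linked implementation scales the rate.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def noam_lr(step: int, d_model: int, warmup: int, factor: float = 1.0) -> float:
    """Noam schedule: factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)."""
    step = max(step, 1)  # step 0 would make step ** -0.5 undefined
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


# Linear growth up to step == warmup, then inverse-square-root decay
for step in (1000, 4000, 16000):
    print(step, noam_lr(step, d_model=512, warmup=4000))
```
</pre></div>
</div>
</div>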
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Set the default weight decay. This is not strictly required, since each parameter group sets its own weight decay. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">weight_decay</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>GPT uses a maximum learning rate of <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">6</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.848448em;vertical-align:0em;"></span><span class="mord"><span class="mord coloredeq eqf" style=""><span class="mord" style="">1</span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.848448em;"><span style="top:-3.09734em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">−</span><span class="mord mtight">4</span></span></span></span></span></span></span></span></span></span></span></span>. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">207</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">6e-4</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span><span class="mord">.9</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span 
class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span><span class="mord">.95</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">209</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">betas</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.95</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal">ϵ</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.848448em;vertical-align:0em;"></span><span class="mord"><span class="mord coloredeq eqf" style=""><span class="mord" style="">1</span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.848448em;"><span style="top:-3.09734em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"></span><span class="mord mtight">8</span></span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">211</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="mf">1e-8</span></pre></div>
</div>
</div>
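<div class='section'>
<div class='docs'>
<p>To show what β<sub>1</sub>, β<sub>2</sub> and ε control, here is a minimal scalar Adam step in plain Python. It is a sketch for illustration, not the optimizer this file configures, and <code>adam_step</code> is a hypothetical helper. β<sub>1</sub> sets the averaging window of the gradient (momentum) and β<sub>2</sub> that of the squared gradient. With β<sub>2</sub> = 0.95 the second-moment estimate averages over roughly 1/(1 − β<sub>2</sub>) = 20 recent steps, instead of about 1000 with PyTorch's default of 0.999. ε keeps the division stable when that estimate is close to zero.</p>
<p>On the first step the bias corrections make the update close to lr · sign(grad), so the parameter moves by about the learning rate.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import math


def adam_step(param, grad, m, v, t, lr=6e-4, betas=(0.9, 0.95), eps=1e-8):
    """One Adam update for a single scalar parameter; t counts from 1."""
    beta1, beta2 = betas
    m = beta1 * m + (1 - beta1) * grad          # EMA of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # EMA of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias-corrected estimates
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


p, m, v = adam_step(1.0, 2.0, 0.0, 0.0, t=1)
print(p)  # roughly 0.9994: the parameter moved by about lr
```
</pre></div>
</div>
</div>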
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Decouple weight decay from the gradient-based update (AdamW-style), instead of adding it to the gradients as L2 regularization. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">213</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">weight_decouple</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
</div>
</div>
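<div class='section'>
<div class='docs'>
<p>A scalar sketch of the difference. It assumes the first Adam step, where the update reduces to g / (|g| + ε), and <code>adam_update</code> is a hypothetical helper. With classic L2 regularization the decay term is added to the gradient and then largely normalized away by Adam's scaling. Decoupled decay shrinks the weight directly by lr × weight_decay × param, independent of the gradient's scale.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import math


def adam_update(param, grad, lr, weight_decay, decoupled, eps=1e-8):
    """First Adam step (t = 1) for one scalar, where m_hat = g and v_hat = g * g."""
    if not decoupled:
        # L2 regularization: decay passes through Adam's adaptive scaling
        grad = grad + weight_decay * param
    step = lr * grad / (math.sqrt(grad * grad) + eps)
    if decoupled:
        # Decoupled (AdamW-style): shrink the weight directly
        param = param - lr * weight_decay * param
    return param - step


# Effect of adding weight decay, with a large gradient
l2_gap = adam_update(1.0, 10.0, 0.1, 0.0, False) - adam_update(1.0, 10.0, 0.1, 0.01, False)
dec_gap = adam_update(1.0, 10.0, 0.1, 0.0, True) - adam_update(1.0, 10.0, 0.1, 0.01, True)
print(l2_gap, dec_gap)  # L2 gap is near zero; decoupled gap is about 0.001
```
</pre></div>
</div>
</div>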
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Total number of optimization steps, which sets the length of the cosine learning-rate decay </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">215</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">total_steps</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">epochs</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">train</span><span class="p">)</span> <span class="o">//</span> <span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">c</span><span class="o">.</span><span class="n">seq_len</span><span class="p">)</span></pre></div>
</div>
</div>
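<div class='section'>
<div class='docs'>
<p>A worked example of this formula with this experiment's settings (32 epochs, batch size 128, context size 128). The training-set size below is a hypothetical round number, not the real length of the Tiny Shakespeare training split. With the character-level tokenizer each character is one token, so every optimizer step consumes 128 × 128 = 16,384 tokens.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
# Hypothetical corpus size; the real len(c.text.train) is not given here
train_chars = 1_000_000
epochs, batch_size, seq_len = 32, 128, 128

tokens_per_step = batch_size * seq_len  # 16,384 characters per optimizer step
# Same left-to-right evaluation as the configuration code above
total_steps = epochs * train_chars // tokens_per_step
print(total_steps)  # 1953
```
</pre></div>
</div>
</div>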
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
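<p>Number of warmup optimization steps. <code>c.warmup_steps</code> is divided by the number of tokens per batch, so it is specified in tokens. </p>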
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">warmup</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">warmup_steps</span> <span class="o">//</span> <span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">c</span><span class="o">.</span><span class="n">seq_len</span><span class="p">)</span>
<span class="lineno">218</span>
<span class="lineno">219</span> <span class="k">return</span> <span class="n">optimizer</span></pre></div>
</div>
</div>
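<div class='section'>
<div class='docs'>
<p>Because <code>c.warmup_steps</code> is divided by the tokens per batch, it behaves as a token budget rather than a step count. A worked example of the conversion, assuming a hypothetical budget of 128 × 128 × 20 tokens:</p>
<p>Expressing warmup in tokens keeps the amount of data seen during warmup fixed when the batch size changes: halving the batch size doubles the number of warmup steps.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
# Hypothetical warmup budget, measured in tokens
warmup_tokens = 128 * 128 * 20
batch_size, seq_len = 128, 128

warmup = warmup_tokens // (batch_size * seq_len)
# Halving the batch size doubles the number of warmup steps
warmup_small_batch = warmup_tokens // (64 * seq_len)
print(warmup, warmup_small_batch)  # 20 40
```
</pre></div>
</div>
</div>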
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">222</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Create experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">224</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;gpt&quot;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>Create configs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">226</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>Override configurations </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Use a character-level tokenizer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">230</span> <span class="s1">&#39;tokenizer&#39;</span><span class="p">:</span> <span class="s1">&#39;character&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Prompt separator is blank </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">232</span> <span class="s1">&#39;prompt_separator&#39;</span><span class="p">:</span> <span class="s1">&#39;&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>Starting prompt for sampling </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">234</span> <span class="s1">&#39;prompt&#39;</span><span class="p">:</span> <span class="s1">&#39;It is &#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>Use the Tiny Shakespeare dataset </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">236</span> <span class="s1">&#39;text&#39;</span><span class="p">:</span> <span class="s1">&#39;tiny_shakespeare&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Use a context size of <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">128</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">239</span> <span class="s1">&#39;seq_len&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>Train for <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">32</span></span></span></span> epochs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">241</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>Batch size <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">128</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">243</span> <span class="s1">&#39;batch_size&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p>Switch between training and validation <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqf" style=""><span class="mord" style="">1</span><span class="mord coloredeq eqh" style=""><span class="mord" style="">0</span></span></span></span></span></span> times per epoch </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">246</span> <span class="s1">&#39;inner_iterations&#39;</span><span class="p">:</span> <span class="mi">10</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Transformer configurations </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">249</span> <span class="s1">&#39;transformer.d_model&#39;</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span>
<span class="lineno">250</span> <span class="s1">&#39;transformer.ffn.d_ff&#39;</span><span class="p">:</span> <span class="mi">2048</span><span class="p">,</span>
<span class="lineno">251</span> <span class="s1">&#39;transformer.n_heads&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span>
<span class="lineno">252</span> <span class="s1">&#39;transformer.n_layers&#39;</span><span class="p">:</span> <span class="mi">6</span>
<span class="lineno">253</span> <span class="p">})</span></pre></div>
</div>
</div>
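<div class='section'>
<div class='docs'>
<p>A standalone plain-Python sketch of what these overrides mean in practice. It covers a character-level tokenizer, cutting the token stream into batches of <code>batch_size</code> sequences of <code>seq_len</code> tokens with next-token targets, and a rough parameter count for the transformer settings. All function names are hypothetical, and labml's own tokenizer and data loader may order or sample sequences differently.</p>
<p>The count assumes biased linear layers and two LayerNorms per layer. It leaves out the embedding and output layers, whose size depends on the vocabulary. <code>n_heads</code> only splits <code>d_model</code> into 8 heads of 64 dimensions, so it does not change the count.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
# --- Character-level tokenizer (hypothetical helpers) ---
def build_char_vocab(text):
    """Map every distinct character to an integer id, sorted for determinism."""
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    return stoi, itos


def encode(text, stoi):
    return [stoi[ch] for ch in text]


def decode(ids, itos):
    return "".join(itos[i] for i in ids)


# --- Cutting a token stream into (input, target) batches ---
def make_batches(ids, batch_size, seq_len):
    """Consecutive windows of seq_len tokens; targets are shifted one token ahead.

    Sequences that do not fill a whole batch are dropped.
    """
    n_seqs = (len(ids) - 1) // seq_len
    starts = [i * seq_len for i in range(n_seqs)]
    batches = []
    for b in range(n_seqs // batch_size):
        batch_starts = starts[b * batch_size:(b + 1) * batch_size]
        inputs = [ids[s:s + seq_len] for s in batch_starts]
        targets = [ids[s + 1:s + seq_len + 1] for s in batch_starts]
        batches.append((inputs, targets))
    return batches


# --- Rough parameter count of the transformer layers ---
def layer_params(d_model, d_ff):
    """One layer, assuming biased linear layers and two LayerNorms."""
    attention = 4 * (d_model * d_model + d_model)  # Q, K, V and output projections
    feed_forward = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
    layer_norms = 2 * (2 * d_model)  # gain and bias of each LayerNorm
    return attention + feed_forward + layer_norms


stoi, itos = build_char_vocab("It is a truth")
ids = encode("It is", stoi)
batches = make_batches(list(range(20)), batch_size=2, seq_len=4)
total = 6 * layer_params(512, 2048)
print(ids, decode(ids, itos), len(batches), total)  # [1, 7, 0, 4, 6] It is 2 18914304
```
</pre></div>
</div>
</div>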
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Set models for saving and loading </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">256</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Start the experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">259</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>Run training </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">261</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">265</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">266</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>