Varuna Jayasiri efd2673735 cleanup
2021-06-02 21:40:05 +05:30

<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This is an annotated PyTorch implementation of the Sketch RNN from paper A Neural Representation of Sketch Drawings. Sketch RNN is a sequence-to-sequence model that generates sketches of objects such as bicycles, cats, etc."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Sketch RNN"/>
<meta name="twitter:description" content="This is an annotated PyTorch implementation of the Sketch RNN from paper A Neural Representation of Sketch Drawings. Sketch RNN is a sequence-to-sequence model that generates sketches of objects such as bicycles, cats, etc."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/sketch_rnn/index.html"/>
<meta property="og:title" content="Sketch RNN"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="This is an annotated PyTorch implementation of the Sketch RNN from paper A Neural Representation of Sketch Drawings. Sketch RNN is a sequence-to-sequence model that generates sketches of objects such as bicycles, cats, etc."/>
<title>Sketch RNN</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/sketch_rnn/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">sketch_rnn</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/sketch_rnn/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Sketch RNN</h1>
<p>This is an annotated <a href="https://pytorch.org">PyTorch</a> implementation of the paper
<a href="https://arxiv.org/abs/1704.03477">A Neural Representation of Sketch Drawings</a>.</p>
<p>Sketch RNN is a sequence-to-sequence variational auto-encoder.
Both the encoder and the decoder are recurrent neural networks.
It learns to reconstruct simple stroke-based drawings by predicting
a series of strokes.
The decoder predicts each stroke as a mixture of Gaussians.</p>
<h3>Getting data</h3>
<p>Download data from <a href="https://github.com/googlecreativelab/quickdraw-dataset">Quick, Draw! Dataset</a>.
There is a link to download <code>npz</code> files in the <em>Sketch-RNN QuickDraw Dataset</em> section of the readme.
Place the downloaded <code>npz</code> file(s) in the <code>data/sketch</code> folder.
This code is configured to use the <code>bicycle</code> dataset.
You can change this in the configurations.</p>
<h3>Acknowledgements</h3>
<p>This implementation draws on the <a href="https://github.com/alexis-jacq/Pytorch-Sketch-RNN">PyTorch Sketch RNN</a> project by
<a href="https://github.com/alexis-jacq">Alexis David Jacq</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span></span><span class="kn">import</span> <span class="nn">math</span>
<span class="lineno">33</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Any</span>
<span class="lineno">34</span>
<span class="lineno">35</span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="lineno">36</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">37</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">38</span><span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="lineno">39</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">optim</span>
<span class="lineno">40</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span><span class="p">,</span> <span class="n">DataLoader</span>
<span class="lineno">41</span>
<span class="lineno">42</span><span class="kn">import</span> <span class="nn">einops</span>
<span class="lineno">43</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">lab</span><span class="p">,</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">monit</span>
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">labml_helpers.device</span> <span class="kn">import</span> <span class="n">DeviceConfigs</span>
<span class="lineno">45</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">46</span><span class="kn">from</span> <span class="nn">labml_helpers.optimizer</span> <span class="kn">import</span> <span class="n">OptimizerConfigs</span>
<span class="lineno">47</span><span class="kn">from</span> <span class="nn">labml_helpers.train_valid</span> <span class="kn">import</span> <span class="n">TrainValidConfigs</span><span class="p">,</span> <span class="n">hook_model_outputs</span><span class="p">,</span> <span class="n">BatchIndex</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Dataset</h2>
<p>This class loads and pre-processes the data.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">50</span><span class="k">class</span> <span class="nc">StrokesDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p><code>dataset</code> is a list of numpy arrays of shape [seq_len, 3].
Each array is a sequence of strokes, and each stroke is represented by
3 integers.
The first two are the displacements along x and y ($\Delta x$, $\Delta y$),
and the last integer represents the state of the pen: $1$ if the pen is lifted
from the paper after this step and $0$ if it stays on the paper.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">max_seq_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
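<div class='section'>
<div class='docs'>
<p>As a rough illustration of this format, the sketch below accumulates the
($\Delta x$, $\Delta y$) offsets of one sequence into absolute pen positions.
The numbers are made up, and <code>seq</code> and <code>to_absolute</code> are hypothetical names
that are not part of this implementation.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import numpy as np

# A made-up sequence of three steps; columns are (dx, dy, pen state)
seq = np.array([[5, 0, 0],
                [0, 5, 0],
                [-5, 0, 1]])


def to_absolute(seq: np.ndarray) -> np.ndarray:
    """Accumulate the (dx, dy) offsets into positions relative to the starting point."""
    return np.cumsum(seq[:, :2], axis=0)


positions = to_absolute(seq)
```
</pre></div>
</div>
</div>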
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span> <span class="n">data</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>We iterate through each of the sequences and filter</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span> <span class="k">for</span> <span class="n">seq</span> <span class="ow">in</span> <span class="n">dataset</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Keep the sequence only if its length is within our range: more than $10$ steps and at most <code>max_seq_length</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="k">if</span> <span class="mi">10</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span> <span class="o">&lt;=</span> <span class="n">max_seq_length</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Clamp $\Delta x$, $\Delta y$ to $[-1000, 1000]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">seq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span><span class="n">seq</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>
<span class="lineno">74</span> <span class="n">seq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">seq</span><span class="p">,</span> <span class="o">-</span><span class="mi">1000</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Convert to a floating point array and add to <code>data</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="n">seq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">seq</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="lineno">77</span> <span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span></pre></div>
</div>
</div>
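<div class='section'>
<div class='docs'>
<p>The filtering and clamping steps can be sketched as one standalone function.
This is a minimal sketch on made-up arrays; <code>filter_and_clamp</code> and its <code>limit</code>
parameter are hypothetical names, and <code>np.clip</code> is used here in place of the
<code>np.minimum</code> and <code>np.maximum</code> pair.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import numpy as np


def filter_and_clamp(dataset, max_seq_length, limit=1000):
    """Keep sequences of 11 to max_seq_length steps and clip their values to [-limit, limit]."""
    data = []
    for seq in dataset:
        if len(seq) in range(11, max_seq_length + 1):
            # np.clip gives the same result as np.minimum followed by np.maximum
            seq = np.clip(seq, -limit, limit)
            data.append(np.array(seq, dtype=np.float32))
    return data


# Made-up data: one sequence that is too short and one valid sequence with an outlier
too_short = np.zeros((5, 3), dtype=np.int64)
valid = np.zeros((12, 3), dtype=np.int64)
valid[0, 0] = 5000
kept = filter_and_clamp([too_short, valid], max_seq_length=200)
```
</pre></div>
</div>
</div>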
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>We then calculate the scaling factor, which is the
standard deviation of ($\Delta x$, $\Delta y$) combined.
The paper notes that, for simplicity, the mean is not adjusted,
since it is close to $0$ anyway.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span> <span class="k">if</span> <span class="n">scale</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">84</span> <span class="n">scale</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">np</span><span class="o">.</span><span class="n">ravel</span><span class="p">(</span><span class="n">s</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">])</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">data</span><span class="p">]))</span>
<span class="lineno">85</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">scale</span></pre></div>
</div>
</div>
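<div class='section'>
<div class='docs'>
<p>A small worked example of the scaling factor, using two made-up sequences.
All $\Delta x$ and $\Delta y$ values are pooled into one flat array before taking the
standard deviation, so a single factor is shared by both axes. The variable names
<code>offsets</code> and <code>normalized</code> are only for illustration.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import numpy as np

# Two made-up sequences; columns are (dx, dy, pen state)
data = [
    np.array([[2.0, -2.0, 0.0], [2.0, -2.0, 1.0]], dtype=np.float32),
    np.array([[-2.0, 2.0, 0.0]], dtype=np.float32),
]

# Pool every dx and dy value from every sequence into one flat array
offsets = np.concatenate([np.ravel(s[:, 0:2]) for s in data])
scale = np.std(offsets)

# Divide by the standard deviation only; the mean is left as it is
normalized = [s[:, 0:2] / scale for s in data]
```
</pre></div>
</div>
</div>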
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Get the longest sequence length among all sequences</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="n">longest_seq_len</span> <span class="o">=</span> <span class="nb">max</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span> <span class="k">for</span> <span class="n">seq</span> <span class="ow">in</span> <span class="n">data</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>We initialize the PyTorch data array with two extra steps for start-of-sequence (sos)
and end-of-sequence (eos).
Each step is a vector $(\Delta x, \Delta y, p_1, p_2, p_3)$.
Only one of $p_1, p_2, p_3$ is $1$ and the others are $0$.
They represent <em>pen down</em>, <em>pen up</em> and <em>end-of-sequence</em> in that order.
$p_1$ is $1$ if the pen touches the paper in the next step.
$p_2$ is $1$ if the pen doesn&rsquo;t touch the paper in the next step.
$p_3$ is $1$ if it is the end of the drawing.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">98</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="n">longest_seq_len</span> <span class="o">+</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>The mask array needs only one extra step since it is for the outputs of the
decoder, which takes in <code>data[:-1]</code> and predicts the next step.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="n">longest_seq_len</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">102</span>
<span class="lineno">103</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">seq</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">data</span><span class="p">):</span>
<span class="lineno">104</span> <span class="n">seq</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span>
<span class="lineno">105</span> <span class="n">len_seq</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Scale and set $\Delta x, \Delta y$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span><span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">seq</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="n">scale</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>$p_1$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span><span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>$p_2$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">111</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span><span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>$p_3$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Mask is on until end of sequence</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">115</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:</span><span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Start-of-sequence is $(0, 0, 1, 0, 0)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">118</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
</div>
</div>
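<div class='section'>
<div class='docs'>
<p>The conversion above can be sketched for a single sequence as follows.
This is an illustrative sketch that uses NumPy instead of PyTorch tensors to stay
lightweight; <code>to_stroke5</code> is a hypothetical helper and the input values are made up.
It follows the same indexing as the dataset class: one start-of-sequence step,
the scaled strokes, then end-of-sequence on every remaining step.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import numpy as np


def to_stroke5(seq: np.ndarray, longest_seq_len: int, scale: float):
    """Convert one [seq_len, 3] sequence into a padded [longest_seq_len + 2, 5] array and a mask."""
    len_seq = len(seq)
    data = np.zeros((longest_seq_len + 2, 5), dtype=np.float32)
    mask = np.zeros(longest_seq_len + 1, dtype=np.float32)
    # Scaled offsets, shifted by one step to leave room for start-of-sequence
    data[1:len_seq + 1, :2] = seq[:, :2] / scale
    # p_1 and p_2 are complementary
    data[1:len_seq + 1, 2] = 1 - seq[:, 2]
    data[1:len_seq + 1, 3] = seq[:, 2]
    # p_3 is set on every step after the drawing ends, including the padding
    data[len_seq + 1:, 4] = 1
    # The mask covers the decoder outputs up to and including the first end-of-sequence step
    mask[:len_seq + 1] = 1
    # Start-of-sequence step
    data[0, 2] = 1
    return data, mask


seq = np.array([[1.0, 2.0, 0.0], [3.0, 4.0, 1.0]], dtype=np.float32)
data, mask = to_stroke5(seq, longest_seq_len=3, scale=1.0)
```
</pre></div>
</div>
</div>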
<div class='section' id='section-18'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Size of the dataset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">120</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Get a sample</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">idx</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<h2>Bi-variate Gaussian mixture</h2>
<p>The mixture is represented by $\Pi$ and
$\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$.
This class adjusts temperatures and creates the categorical and Gaussian
distributions from the parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span><span class="k">class</span> <span class="nc">BivariateGaussianMixture</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pi_logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mu_x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mu_y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">140</span> <span class="n">sigma_x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sigma_y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">rho_xy</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">141</span> <span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span> <span class="o">=</span> <span class="n">pi_logits</span>
<span class="lineno">142</span> <span class="bp">self</span><span class="o">.</span><span class="n">mu_x</span> <span class="o">=</span> <span class="n">mu_x</span>
<span class="lineno">143</span> <span class="bp">self</span><span class="o">.</span><span class="n">mu_y</span> <span class="o">=</span> <span class="n">mu_y</span>
<span class="lineno">144</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma_x</span> <span class="o">=</span> <span class="n">sigma_x</span>
<span class="lineno">145</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma_y</span> <span class="o">=</span> <span class="n">sigma_y</span>
<span class="lineno">146</span> <span class="bp">self</span><span class="o">.</span><span class="n">rho_xy</span> <span class="o">=</span> <span class="n">rho_xy</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Number of distributions in the mixture, $M$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span> <span class="nd">@property</span>
<span class="lineno">149</span> <span class="k">def</span> <span class="nf">n_distributions</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">151</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Adjust by temperature $\tau$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span> <span class="k">def</span> <span class="nf">set_temperature</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">temperature</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>
<script type="math/tex; mode=display">\hat{\Pi_k} \leftarrow \frac{\hat{\Pi_k}}{\tau}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span> <span class="o">/=</span> <span class="n">temperature</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>
<script type="math/tex; mode=display">\sigma^2_x \leftarrow \sigma^2_x \tau</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma_x</span> <span class="o">*=</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">temperature</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>
<script type="math/tex; mode=display">\sigma^2_y \leftarrow \sigma^2_y \tau</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma_y</span> <span class="o">*=</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">temperature</span><span class="p">)</span></pre></div>
</div>
</div>
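<div class='section'>
<div class='docs'>
<p>To see the effect of the temperature, here is a minimal sketch in plain Python
with made-up values. The helpers <code>softmax</code> and <code>adjust</code> are hypothetical and only
mirror the updates above for a single $\sigma$: dividing the logits by $\tau$ and multiplying
the standard deviation by $\sqrt{\tau}$, which multiplies the variance by $\tau$.
With $\tau$ below $1$ the mixture weights concentrate on the most likely component
and the Gaussians become narrower.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def adjust(pi_logits, sigma, temperature):
    """Divide the logits by tau and multiply the standard deviation by sqrt(tau)."""
    return [v / temperature for v in pi_logits], sigma * math.sqrt(temperature)


# Made-up mixture logits and standard deviation
pi_logits, sigma = [2.0, 1.0, 0.0], 0.5
cold_logits, cold_sigma = adjust(pi_logits, sigma, temperature=0.25)
probs = softmax(pi_logits)
cold_probs = softmax(cold_logits)
```
</pre></div>
</div>
</div>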
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="k">def</span> <span class="nf">get_distribution</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Clamp $\sigma_x$, $\sigma_y$ and $\rho_{xy}$ to avoid getting <code>NaN</code>s</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">sigma_x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp_min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sigma_x</span><span class="p">,</span> <span class="mf">1e-5</span><span class="p">)</span>
<span class="lineno">167</span> <span class="n">sigma_y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp_min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sigma_y</span><span class="p">,</span> <span class="mf">1e-5</span><span class="p">)</span>
<span class="lineno">168</span> <span class="n">rho_xy</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">rho_xy</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span> <span class="o">+</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="mi">1</span> <span class="o">-</span> <span class="mf">1e-5</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Get means</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">mu_x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mu_y</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Get covariance matrix</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">cov</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span>
<span class="lineno">174</span> <span class="n">sigma_x</span> <span class="o">*</span> <span class="n">sigma_x</span><span class="p">,</span> <span class="n">rho_xy</span> <span class="o">*</span> <span class="n">sigma_x</span> <span class="o">*</span> <span class="n">sigma_y</span><span class="p">,</span>
<span class="lineno">175</span> <span class="n">rho_xy</span> <span class="o">*</span> <span class="n">sigma_x</span> <span class="o">*</span> <span class="n">sigma_y</span><span class="p">,</span> <span class="n">sigma_y</span> <span class="o">*</span> <span class="n">sigma_y</span>
<span class="lineno">176</span> <span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">177</span> <span class="n">cov</span> <span class="o">=</span> <span class="n">cov</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">sigma_y</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Create bi-variate normal distribution.</p>
<p>📝 It would be more efficient to pass a <code>scale_tril</code> matrix of the form <code>[[a, 0], [b, c]]</code>,
where
<script type="math/tex; mode=display">a = \sigma_x, b = \rho_{xy} \sigma_y, c = \sigma_y \sqrt{1 - \rho^2_{xy}}</script>
but for simplicity we use the covariance matrix.
<a href="https://www2.stat.duke.edu/courses/Spring12/sta104.1/Lectures/Lec22.pdf">This is a good resource</a>
if you want to read more about bi-variate distributions, their covariance matrices,
and probability density functions.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">188</span> <span class="n">multi_dist</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">MultivariateNormal</span><span class="p">(</span><span class="n">mean</span><span class="p">,</span> <span class="n">covariance_matrix</span><span class="o">=</span><span class="n">cov</span><span class="p">)</span></pre></div>
</div>
</div>
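<div class='section'>
<div class='docs'>
<p>The note above can be checked with a minimal sketch. It assumes a single bi-variate Gaussian
with made-up parameter values, builds both the covariance matrix and the lower-triangular factor
<code>[[a, 0], [b, c]]</code>, and evaluates the log-density of one point under each parameterization.
The variable names are illustrative and are not part of the implementation.
If the factor is right, <code>scale_tril @ scale_tril.T</code> reproduces <code>cov</code>
and the two log-densities agree up to floating point error.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch

# Illustrative parameters for a single bi-variate Gaussian (values are made up)
sigma_x = torch.tensor(1.5)
sigma_y = torch.tensor(0.7)
rho_xy = torch.tensor(0.3)
mean = torch.tensor([0.2, -0.1])

# Full covariance matrix, built the same way as in the code above
cov = torch.stack([
    sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
    rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
]).view(2, 2)

# Lower-triangular factor [[a, 0], [b, c]] from the note above
a = sigma_x
b = rho_xy * sigma_y
c = sigma_y * torch.sqrt(1 - rho_xy ** 2)
scale_tril = torch.stack([a, torch.zeros_like(a), b, c]).view(2, 2)

dist_cov = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
dist_tril = torch.distributions.MultivariateNormal(mean, scale_tril=scale_tril)

# Evaluate the log-density of an arbitrary point under both parameterizations
point = torch.tensor([0.5, 0.4])
log_prob_cov = dist_cov.log_prob(point)
log_prob_tril = dist_tril.log_prob(point)
```
</pre></div>
</div>
</div>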
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Create categorical distribution $\Pi$ from logits</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="n">cat_dist</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">194</span> <span class="k">return</span> <span class="n">cat_dist</span><span class="p">,</span> <span class="n">multi_dist</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<h2>Encoder module</h2>
<p>This consists of a bidirectional LSTM</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span><span class="k">class</span> <span class="nc">EncoderRNN</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">204</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_z</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">enc_hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">205</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Create a bidirectional LSTM taking a sequence of
$(\Delta x, \Delta y, p_1, p_2, p_3)$ as input.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">208</span> <span class="bp">self</span><span class="o">.</span><span class="n">lstm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LSTM</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">enc_hidden_size</span><span class="p">,</span> <span class="n">bidirectional</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Head to get $\mu$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">210</span> <span class="bp">self</span><span class="o">.</span><span class="n">mu_head</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">enc_hidden_size</span><span class="p">,</span> <span class="n">d_z</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Head to get $\hat{\sigma}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">212</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma_head</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">enc_hidden_size</span><span class="p">,</span> <span class="n">d_z</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">214</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">state</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>The hidden state of the bidirectional LSTM is the concatenation of the
output of the last token in the forward direction and
the output of the first token in the reverse direction, which is what we want.
<script type="math/tex; mode=display">h_{\rightarrow} = encode_{\rightarrow}(S),
h_{\leftarrow} = encode_{\leftarrow}(S_{reverse}),
h = [h_{\rightarrow}; h_{\leftarrow}]</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">221</span> <span class="n">_</span><span class="p">,</span> <span class="p">(</span><span class="n">hidden</span><span class="p">,</span> <span class="n">cell</span><span class="p">)</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lstm</span><span class="p">(</span><span class="n">inputs</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">state</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>The state has shape <code>[2, batch_size, hidden_size]</code>,
where the first dimension is the direction.
We rearrange it to get $h = [h_{\rightarrow}; h_{\leftarrow}]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">225</span> <span class="n">hidden</span> <span class="o">=</span> <span class="n">einops</span><span class="o">.</span><span class="n">rearrange</span><span class="p">(</span><span class="n">hidden</span><span class="p">,</span> <span class="s1">&#39;fb b h -&gt; b (fb h)&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
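<div class='section'>
<div class='docs'>
<p>A small sketch of the two statements above, using illustrative sizes and random inputs.
It checks that the final hidden state of each direction matches the corresponding slice of the
LSTM outputs, and it reproduces the rearrangement with plain PyTorch
(<code>permute</code> followed by <code>reshape</code>) instead of <code>einops</code>.
That substitution is an assumption made to keep the example dependency-free.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
from torch import nn

torch.manual_seed(0)
seq_len, batch_size, hidden_size = 7, 3, 4  # illustrative sizes
lstm = nn.LSTM(5, hidden_size, bidirectional=True)
inputs = torch.randn(seq_len, batch_size, 5)

# outputs: [seq_len, batch_size, 2 * hidden_size], hidden: [2, batch_size, hidden_size]
outputs, (hidden, cell) = lstm(inputs)

# Forward direction: final hidden state equals the output at the last step
forward_matches = torch.allclose(hidden[0], outputs[-1, :, :hidden_size])
# Reverse direction: final hidden state equals the output at the first step
reverse_matches = torch.allclose(hidden[1], outputs[0, :, hidden_size:])

# Plain-PyTorch equivalent of the einops rearrangement used above
h = hidden.permute(1, 0, 2).reshape(batch_size, 2 * hidden_size)
same_as_cat = torch.equal(h, torch.cat([hidden[0], hidden[1]], dim=-1))
```
</pre></div>
</div>
</div>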
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>$\mu$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">mu</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mu_head</span><span class="p">(</span><span class="n">hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>$\hat{\sigma}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">230</span> <span class="n">sigma_hat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma_head</span><span class="p">(</span><span class="n">hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>$\sigma = \exp(\frac{\hat{\sigma}}{2})$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">232</span> <span class="n">sigma</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">sigma_hat</span> <span class="o">/</span> <span class="mf">2.</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>Sample $z = \mu + \sigma \cdot \mathcal{N}(0, I)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">235</span> <span class="n">z</span> <span class="o">=</span> <span class="n">mu</span> <span class="o">+</span> <span class="n">sigma</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">mu</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">mu</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="n">mu</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">mu</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">238</span> <span class="k">return</span> <span class="n">z</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">sigma_hat</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<h2>Decoder module</h2>
<p>This consists of an LSTM</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">241</span><span class="k">class</span> <span class="nc">DecoderRNN</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">248</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_z</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dec_hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_distributions</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">249</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>LSTM takes $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ as input</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">251</span> <span class="bp">self</span><span class="o">.</span><span class="n">lstm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LSTM</span><span class="p">(</span><span class="n">d_z</span> <span class="o">+</span> <span class="mi">5</span><span class="p">,</span> <span class="n">dec_hidden_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Initial state of the LSTM is $[h_0; c_0] = \tanh(W_{z}z + b_z)$.
<code>init_state</code> is the linear transformation for this</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">255</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_state</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_z</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">dec_hidden_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>This layer produces outputs for each of the <code>n_distributions</code>.
Each distribution needs six parameters
$(\hat{\Pi_i}, \mu_{x_i}, \mu_{y_i}, \hat{\sigma_{x_i}}, \hat{\sigma_{y_i}}, \hat{\rho_{xy_i}})$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">260</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixtures</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dec_hidden_size</span><span class="p">,</span> <span class="mi">6</span> <span class="o">*</span> <span class="n">n_distributions</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>This head is for the logits $(\hat{q_1}, \hat{q_2}, \hat{q_3})$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">263</span> <span class="bp">self</span><span class="o">.</span><span class="n">q_head</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dec_hidden_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p>This is to calculate $\log(q_k)$ where
<script type="math/tex; mode=display">q_k = \operatorname{softmax}(\hat{q})_k = \frac{\exp(\hat{q_k})}{\sum_{j = 1}^3 \exp(\hat{q_j})}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">266</span> <span class="bp">self</span><span class="o">.</span><span class="n">q_log_softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LogSoftmax</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>These parameters are stored for future reference</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">269</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_distributions</span> <span class="o">=</span> <span class="n">n_distributions</span>
<span class="lineno">270</span> <span class="bp">self</span><span class="o">.</span><span class="n">dec_hidden_size</span> <span class="o">=</span> <span class="n">dec_hidden_size</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">272</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">z</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]):</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Calculate the initial state</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">274</span> <span class="k">if</span> <span class="n">state</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>$[h_0; c_0] = \tanh(W_{z}z + b_z)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">276</span> <span class="n">h</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">init_state</span><span class="p">(</span><span class="n">z</span><span class="p">)),</span> <span class="bp">self</span><span class="o">.</span><span class="n">dec_hidden_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p><code>h</code> and <code>c</code> have shapes <code>[batch_size, lstm_size]</code>. We reshape them
to <code>[1, batch_size, lstm_size]</code> because that&rsquo;s the shape the LSTM expects for its state.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">279</span> <span class="n">state</span> <span class="o">=</span> <span class="p">(</span><span class="n">h</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">(),</span> <span class="n">c</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<p>Run the LSTM</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">282</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lstm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">state</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>Get $\log(q)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">285</span> <span class="n">q_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">q_log_softmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">q_head</span><span class="p">(</span><span class="n">outputs</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<p>Get $(\hat{\Pi_i}, \mu_{x,i}, \mu_{y,i}, \hat{\sigma_{x,i}},
\hat{\sigma_{y,i}}, \hat{\rho_{xy,i}})$.
<code>torch.split</code> splits the output into 6 tensors of size <code>self.n_distributions</code>
across dimension <code>2</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">291</span> <span class="n">pi_logits</span><span class="p">,</span> <span class="n">mu_x</span><span class="p">,</span> <span class="n">mu_y</span><span class="p">,</span> <span class="n">sigma_x</span><span class="p">,</span> <span class="n">sigma_y</span><span class="p">,</span> <span class="n">rho_xy</span> <span class="o">=</span> \
<span class="lineno">292</span> <span class="n">torch</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mixtures</span><span class="p">(</span><span class="n">outputs</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_distributions</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
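<div class='section'>
<div class='docs'>
<p>To make the split concrete, here is a minimal sketch with illustrative sizes.
<code>mixture_out</code> is a hypothetical stand-in for the output of the <code>mixtures</code> layer.
Passing an integer to <code>torch.split</code> produces consecutive chunks of that size along the chosen
dimension, so a last dimension of <code>6 * n_distributions</code> yields six tensors.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch

seq_len, batch_size, n_distributions = 4, 2, 20  # illustrative sizes
# Hypothetical stand-in for the output of the mixtures layer
mixture_out = torch.randn(seq_len, batch_size, 6 * n_distributions)

# An integer split size yields equal consecutive chunks along dimension 2
parts = torch.split(mixture_out, n_distributions, 2)
pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = parts
```
</pre></div>
</div>
</div>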
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<p>Create a bi-variate Gaussian mixture
$\Pi$ and
$\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$
where
<script type="math/tex; mode=display">\sigma_{x,i} = \exp(\hat{\sigma_{x,i}}), \sigma_{y,i} = \exp(\hat{\sigma_{y,i}}),
\rho_{xy,i} = \tanh(\hat{\rho_{xy,i}})</script>
and
<script type="math/tex; mode=display">\Pi_i = \operatorname{softmax}(\hat{\Pi})_i = \frac{\exp(\hat{\Pi_i})}{\sum_{j = 1}^M \exp(\hat{\Pi_j})}</script>
</p>
<p>Here $M$ is the number of distributions in the mixture (<code>n_distributions</code>).
$\Pi$ gives the categorical probabilities of choosing each distribution out of the mixture
$\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">305</span> <span class="n">dist</span> <span class="o">=</span> <span class="n">BivariateGaussianMixture</span><span class="p">(</span><span class="n">pi_logits</span><span class="p">,</span> <span class="n">mu_x</span><span class="p">,</span> <span class="n">mu_y</span><span class="p">,</span>
<span class="lineno">306</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">sigma_x</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">sigma_y</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">rho_xy</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">309</span> <span class="k">return</span> <span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">,</span> <span class="n">state</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
<h2>Reconstruction Loss</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">312</span><span class="k">class</span> <span class="nc">ReconstructionLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">317</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">318</span> <span class="n">dist</span><span class="p">:</span> <span class="s1">&#39;BivariateGaussianMixture&#39;</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
<p>Get $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">320</span> <span class="n">pi</span><span class="p">,</span> <span class="n">mix</span> <span class="o">=</span> <span class="n">dist</span><span class="o">.</span><span class="n">get_distribution</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-70'>
<div class='docs'>
<div class='section-link'>
<a href='#section-70'>#</a>
</div>
<p><code>target</code> has shape <code>[seq_len, batch_size, 5]</code> where the last dimension is the features
$(\Delta x, \Delta y, p_1, p_2, p_3)$.
We want to get $\Delta x, \Delta y$ and get the probabilities from each of the distributions
in the mixture $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$.</p>
<p><code>xy</code> will have shape <code>[seq_len, batch_size, n_distributions, 2]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">327</span> <span class="n">xy</span> <span class="o">=</span> <span class="n">target</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">dist</span><span class="o">.</span><span class="n">n_distributions</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-71'>
<div class='docs'>
<div class='section-link'>
<a href='#section-71'>#</a>
</div>
<p>Calculate the probabilities
<script type="math/tex; mode=display">p(\Delta x, \Delta y) =
\sum_{j=1}^M \Pi_j \mathcal{N} \big( \Delta x, \Delta y \vert
\mu_{x,j}, \mu_{y,j}, \sigma_{x,j}, \sigma_{y,j}, \rho_{xy,j}
\big)</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">333</span> <span class="n">probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">pi</span><span class="o">.</span><span class="n">probs</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">mix</span><span class="o">.</span><span class="n">log_prob</span><span class="p">(</span><span class="n">xy</span><span class="p">)),</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
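<div class='section'>
<div class='docs'>
<p>As a side note, the logarithm of this mixture probability can also be computed entirely in log space
with <code>torch.logsumexp</code>. This is not what the implementation above does. It is a sketch of an
alternative, with random stand-in tensors in place of <code>pi.probs</code> and <code>mix.log_prob(xy)</code>.
Both forms should agree on well-scaled inputs. The log-space form avoids exponentiating
very negative log-probabilities before summing.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch

torch.manual_seed(0)
seq_len, batch_size, n_distributions = 3, 2, 4  # illustrative sizes
pi_logits = torch.randn(seq_len, batch_size, n_distributions)
# Hypothetical stand-in for mix.log_prob(xy)
component_log_prob = torch.randn(seq_len, batch_size, n_distributions)

# Direct form: sum of mixture weights times component densities, then log
pi_probs = torch.softmax(pi_logits, dim=-1)
log_direct = torch.log(torch.sum(pi_probs * torch.exp(component_log_prob), 2))

# Log-space form: log-sum-exp over the mixture dimension
log_pi = torch.log_softmax(pi_logits, dim=-1)
log_stable = torch.logsumexp(log_pi + component_log_prob, dim=2)
```
</pre></div>
</div>
</div>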
<div class='section' id='section-72'>
<div class='docs'>
<div class='section-link'>
<a href='#section-72'>#</a>
</div>
<p>
<script type="math/tex; mode=display">L_s = - \frac{1}{N_{max}} \sum_{i=1}^{N_s} \log \big (p(\Delta x, \Delta y) \big)</script>
Although <code>probs</code> has $N_{max}$ (<code>longest_seq_len</code>) elements, the sum is only taken
up to $N_s$ because the rest is masked out.</p>
<p>It might feel like we should be taking the sum and dividing by $N_s$ and not $N_{max}$,
but that would give higher weight to individual predictions in shorter sequences.
We give equal weight to each prediction $p(\Delta x, \Delta y)$ when we divide by $N_{max}$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">342</span> <span class="n">loss_stroke</span> <span class="o">=</span> <span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">mask</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mf">1e-5</span> <span class="o">+</span> <span class="n">probs</span><span class="p">))</span></pre></div>
</div>
</div>
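<div class='section'>
<div class='docs'>
<p>The weighting argument above can be illustrated with a toy batch. The sketch assumes two sequences
padded to $N_{max} = 4$, one with four real steps and one with two, and pretends every step has the same
log-probability. The names <code>loss_n_max</code> and <code>loss_n_s</code> are illustrative.
With the $N_{max}$ normalization each real step contributes the same amount to the loss.
With the per-sequence $N_s$ normalization each step of the shorter sequence contributes twice as much
as a step of the longer one.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch

# Two sequences padded to N_max = 4: the first has 4 real steps, the second has 2
mask = torch.tensor([[1., 1.], [1., 1.], [1., 0.], [1., 0.]])  # [seq_len, batch_size]
# Pretend every step has log-probability -2
log_probs = torch.full((4, 2), -2.0)

# Same form as the loss above: mean over all N_max * batch_size positions
loss_n_max = -torch.mean(mask * log_probs)

# Alternative: normalize each sequence by its own length N_s, then average over the batch
loss_n_s = -torch.mean(torch.sum(mask * log_probs, 0) / torch.sum(mask, 0))
```
</pre></div>
</div>
</div>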
<div class='section' id='section-73'>
<div class='docs'>
<div class='section-link'>
<a href='#section-73'>#</a>
</div>
<p>
<script type="math/tex; mode=display">L_p = - \frac{1}{N_{max}} \sum_{i=1}^{N_{max}} \sum_{k=1}^{3} p_{k,i} \log(q_{k,i})</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">345</span> <span class="n">loss_pen</span> <span class="o">=</span> <span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">target</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">*</span> <span class="n">q_logits</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-74'>
<div class='docs'>
<div class='section-link'>
<a href='#section-74'>#</a>
</div>
<p>
<script type="math/tex; mode=display">L_R = L_s + L_p</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">348</span> <span class="k">return</span> <span class="n">loss_stroke</span> <span class="o">+</span> <span class="n">loss_pen</span></pre></div>
</div>
</div>
<div class='section' id='section-75'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-75'>#</a>
</div>
<h2>KL-Divergence loss</h2>
<p>This calculates the KL divergence between a given normal distribution and $\mathcal{N}(0, 1)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">351</span><span class="k">class</span> <span class="nc">KLDivLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-76'>
<div class='docs'>
<div class='section-link'>
<a href='#section-76'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">358</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sigma_hat</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mu</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-77'>
<div class='docs'>
<div class='section-link'>
<a href='#section-77'>#</a>
</div>
<p>
<script type="math/tex; mode=display">L_{KL} = - \frac{1}{2 N_z} \bigg( 1 + \hat{\sigma} - \mu^2 - \exp(\hat{\sigma}) \bigg)</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">360</span> <span class="k">return</span> <span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">sigma_hat</span> <span class="o">-</span> <span class="n">mu</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">sigma_hat</span><span class="p">))</span></pre></div>
</div>
</div>
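<div class='section'>
<div class='docs'>
<p>The closed form above can be checked numerically. This sketch assumes $\hat{\sigma}$ is the log-variance, so that $\sigma = \exp(\hat{\sigma} / 2)$, and compares the expression with <code>torch.distributions.kl_divergence</code>. The function name <code>kl_div_loss</code> and the tensor shapes are illustrative.</p>
</div>
<div class='code'>

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_div_loss(sigma_hat: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) with sigma_hat = log(sigma^2),
    # averaged over all elements
    return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))


torch.manual_seed(0)
mu = torch.randn(8, 128)
sigma_hat = torch.randn(8, 128)

closed_form = kl_div_loss(sigma_hat, mu)

# Reference value from torch.distributions, using sigma = exp(sigma_hat / 2)
posterior = Normal(mu, torch.exp(sigma_hat / 2))
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))
reference = kl_divergence(posterior, prior).mean()
```

</div>
</div>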
<div class='section' id='section-78'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-78'>#</a>
</div>
<h2>Sampler</h2>
<p>This samples a sketch from the decoder and plots it</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">363</span><span class="k">class</span> <span class="nc">Sampler</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-79'>
<div class='docs'>
<div class='section-link'>
<a href='#section-79'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">370</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">EncoderRNN</span><span class="p">,</span> <span class="n">decoder</span><span class="p">:</span> <span class="n">DecoderRNN</span><span class="p">):</span>
<span class="lineno">371</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">decoder</span>
<span class="lineno">372</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span></pre></div>
</div>
</div>
<div class='section' id='section-80'>
<div class='docs'>
<div class='section-link'>
<a href='#section-80'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">374</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">temperature</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-81'>
<div class='docs'>
<div class='section-link'>
<a href='#section-81'>#</a>
</div>
<p>$N_{max}$, the length of the longest sequence</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">376</span> <span class="n">longest_seq_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-82'>
<div class='docs'>
<div class='section-link'>
<a href='#section-82'>#</a>
</div>
<p>Get $z$ from the encoder</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">379</span> <span class="n">z</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-83'>
<div class='docs'>
<div class='section-link'>
<a href='#section-83'>#</a>
</div>
<p>Start-of-sequence stroke is $(0, 0, 1, 0, 0)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">382</span> <span class="n">s</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span>
<span class="lineno">383</span> <span class="n">seq</span> <span class="o">=</span> <span class="p">[</span><span class="n">s</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-84'>
<div class='docs'>
<div class='section-link'>
<a href='#section-84'>#</a>
</div>
<p>Initial decoder state is <code>None</code>.
The decoder will initialize it to $[h_0; c_0] = \tanh(W_{z}z + b_z)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">386</span> <span class="n">state</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
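<div class='section'>
<div class='docs'>
<p>A rough sketch of how $[h_0; c_0] = \tanh(W_{z}z + b_z)$ could be computed and handed to an LSTM. The layer name <code>init_state</code>, the stand-alone <code>nn.LSTM</code> and the batch size are assumptions for illustration; the sizes of $z$ and the decoder hidden state follow the configurations.</p>
</div>
<div class='code'>

```python
import torch
from torch import nn

# d_z and dec_hidden_size follow the configurations; the batch size is illustrative
d_z, dec_hidden_size, batch_size = 128, 512, 4

# Hypothetical linear layer mapping z to the concatenated [h_0; c_0]
init_state = nn.Linear(d_z, 2 * dec_hidden_size)

z = torch.randn(batch_size, d_z)
h0, c0 = torch.split(torch.tanh(init_state(z)), dec_hidden_size, dim=1)

# nn.LSTM expects states shaped [num_layers, batch_size, hidden_size]
state = (h0.unsqueeze(0).contiguous(), c0.unsqueeze(0).contiguous())

# One decoding step with input [(dx, dy, p1, p2, p3); z]
lstm = nn.LSTM(5 + d_z, dec_hidden_size)
inputs = torch.randn(1, batch_size, 5 + d_z)
outputs, next_state = lstm(inputs, state)
```

</div>
</div>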
<div class='section' id='section-85'>
<div class='docs'>
<div class='section-link'>
<a href='#section-85'>#</a>
</div>
<p>We don&rsquo;t need gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">389</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-86'>
<div class='docs'>
<div class='section-link'>
<a href='#section-86'>#</a>
</div>
<p>Sample $N_{max}$ strokes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">391</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">longest_seq_len</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-87'>
<div class='docs'>
<div class='section-link'>
<a href='#section-87'>#</a>
</div>
<p>$[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ is the input to the decoder</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">393</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">s</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">z</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)],</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-88'>
<div class='docs'>
<div class='section-link'>
<a href='#section-88'>#</a>
</div>
<p>Get $\Pi$, $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$,
$q$ and the next state from the decoder</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">396</span> <span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="n">state</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-89'>
<div class='docs'>
<div class='section-link'>
<a href='#section-89'>#</a>
</div>
<p>Sample a stroke</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">398</span> <span class="n">s</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sample_step</span><span class="p">(</span><span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">,</span> <span class="n">temperature</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-90'>
<div class='docs'>
<div class='section-link'>
<a href='#section-90'>#</a>
</div>
<p>Add the new stroke to the sequence of strokes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">400</span> <span class="n">seq</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">s</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-91'>
<div class='docs'>
<div class='section-link'>
<a href='#section-91'>#</a>
</div>
<p>Stop sampling if $p_3 = 1$. This indicates that sketching has stopped</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">402</span> <span class="k">if</span> <span class="n">s</span><span class="p">[</span><span class="mi">4</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="lineno">403</span> <span class="k">break</span></pre></div>
</div>
</div>
<div class='section' id='section-92'>
<div class='docs'>
<div class='section-link'>
<a href='#section-92'>#</a>
</div>
<p>Create a PyTorch tensor of the sequence of strokes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">406</span> <span class="n">seq</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-93'>
<div class='docs'>
<div class='section-link'>
<a href='#section-93'>#</a>
</div>
<p>Plot the sequence of strokes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">409</span> <span class="bp">self</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-94'>
<div class='docs'>
<div class='section-link'>
<a href='#section-94'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">411</span> <span class="nd">@staticmethod</span>
<span class="lineno">412</span> <span class="k">def</span> <span class="nf">_sample_step</span><span class="p">(</span><span class="n">dist</span><span class="p">:</span> <span class="s1">&#39;BivariateGaussianMixture&#39;</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">temperature</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-95'>
<div class='docs'>
<div class='section-link'>
<a href='#section-95'>#</a>
</div>
<p>Set temperature $\tau$ for sampling. This is implemented in class <code>BivariateGaussianMixture</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">414</span> <span class="n">dist</span><span class="o">.</span><span class="n">set_temperature</span><span class="p">(</span><span class="n">temperature</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-96'>
<div class='docs'>
<div class='section-link'>
<a href='#section-96'>#</a>
</div>
<p>Get temperature adjusted $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">416</span> <span class="n">pi</span><span class="p">,</span> <span class="n">mix</span> <span class="o">=</span> <span class="n">dist</span><span class="o">.</span><span class="n">get_distribution</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-97'>
<div class='docs'>
<div class='section-link'>
<a href='#section-97'>#</a>
</div>
<p>Sample from $\Pi$ the index of the distribution to use from the mixture</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">418</span> <span class="n">idx</span> <span class="o">=</span> <span class="n">pi</span><span class="o">.</span><span class="n">sample</span><span class="p">()[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-98'>
<div class='docs'>
<div class='section-link'>
<a href='#section-98'>#</a>
</div>
<p>Create the categorical distribution $q$ from the log-probabilities <code>q_logits</code> (that is, $\hat{q}$), scaled by the temperature $\tau$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">421</span> <span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">q_logits</span> <span class="o">/</span> <span class="n">temperature</span><span class="p">)</span></pre></div>
</div>
</div>
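<div class='section'>
<div class='docs'>
<p>To illustrate the effect of dividing the logits by the temperature, the sketch below builds the same categorical distribution at a few temperatures. The pen-state probabilities are made-up values; lower temperatures concentrate probability on the most likely state.</p>
</div>
<div class='code'>

```python
import torch
from torch.distributions import Categorical

# Hypothetical log-probabilities for the three pen states
q_logits = torch.log(torch.tensor([0.7, 0.2, 0.1]))

probs = {}
for temperature in (1.0, 0.4, 0.1):
    # Dividing by a temperature below 1 sharpens the distribution
    q = Categorical(logits=q_logits / temperature)
    probs[temperature] = q.probs
```

</div>
</div>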
<div class='section' id='section-99'>
<div class='docs'>
<div class='section-link'>
<a href='#section-99'>#</a>
</div>
<p>Sample from $q$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">423</span> <span class="n">q_idx</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">sample</span><span class="p">()[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-100'>
<div class='docs'>
<div class='section-link'>
<a href='#section-100'>#</a>
</div>
<p>Sample from the normal distributions in the mixture and pick the one indexed by <code>idx</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">426</span> <span class="n">xy</span> <span class="o">=</span> <span class="n">mix</span><span class="o">.</span><span class="n">sample</span><span class="p">()[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">idx</span><span class="p">]</span></pre></div>
</div>
</div>
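<div class='section'>
<div class='docs'>
<p>The two-stage sampling (pick a component from $\Pi$, then keep that component's sample) can be sketched for a single time step as follows. This is not the code of <code>BivariateGaussianMixture</code>; the parameter values are random placeholders and the covariance construction is an assumption based on $(\sigma_x, \sigma_y, \rho_{xy})$.</p>
</div>
<div class='code'>

```python
import torch
from torch.distributions import Categorical, MultivariateNormal

torch.manual_seed(0)
n_distributions = 20

# Hypothetical mixture parameters for a single time step
pi_logits = torch.randn(n_distributions)
mean = torch.randn(n_distributions, 2)
sigma_x = torch.rand(n_distributions) + 0.1
sigma_y = torch.rand(n_distributions) + 0.1
rho_xy = torch.tanh(torch.randn(n_distributions)) * 0.9

# Covariance matrix of each bivariate normal, shape [n_distributions, 2, 2]
cov_xy = rho_xy * sigma_x * sigma_y
cov = torch.stack([
    torch.stack([sigma_x ** 2, cov_xy], dim=-1),
    torch.stack([cov_xy, sigma_y ** 2], dim=-1),
], dim=-2)

pi = Categorical(logits=pi_logits)
mix = MultivariateNormal(mean, covariance_matrix=cov)

# Pick a component index, then keep only that component's (dx, dy) sample
idx = pi.sample()
xy = mix.sample()[idx]
```

</div>
</div>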
<div class='section' id='section-101'>
<div class='docs'>
<div class='section-link'>
<a href='#section-101'>#</a>
</div>
<p>Create an empty stroke $(\Delta x, \Delta y, q_1, q_2, q_3)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">429</span> <span class="n">stroke</span> <span class="o">=</span> <span class="n">q_logits</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-102'>
<div class='docs'>
<div class='section-link'>
<a href='#section-102'>#</a>
</div>
<p>Set $\Delta x, \Delta y$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">431</span> <span class="n">stroke</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">xy</span></pre></div>
</div>
</div>
<div class='section' id='section-103'>
<div class='docs'>
<div class='section-link'>
<a href='#section-103'>#</a>
</div>
<p>Set $q_1, q_2, q_3$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">433</span> <span class="n">stroke</span><span class="p">[</span><span class="n">q_idx</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-104'>
<div class='docs'>
<div class='section-link'>
<a href='#section-104'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">435</span> <span class="k">return</span> <span class="n">stroke</span></pre></div>
</div>
</div>
<div class='section' id='section-105'>
<div class='docs'>
<div class='section-link'>
<a href='#section-105'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">437</span> <span class="nd">@staticmethod</span>
<span class="lineno">438</span> <span class="k">def</span> <span class="nf">plot</span><span class="p">(</span><span class="n">seq</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-106'>
<div class='docs'>
<div class='section-link'>
<a href='#section-106'>#</a>
</div>
<p>Take the cumulative sums of $(\Delta x, \Delta y)$ to get $(x, y)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">440</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">seq</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-107'>
<div class='docs'>
<div class='section-link'>
<a href='#section-107'>#</a>
</div>
<p>Create a new numpy array of the form $(x, y, q_2)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">442</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span>
<span class="lineno">443</span> <span class="n">seq</span> <span class="o">=</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">3</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-108'>
<div class='docs'>
<div class='section-link'>
<a href='#section-108'>#</a>
</div>
<p>Split the array at the points where $q_2$ is $1$,
that is, at the points where the pen is lifted from the paper.
This gives a list of stroke sequences.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">448</span> <span class="n">strokes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">seq</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">seq</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
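<div class='section'>
<div class='docs'>
<p>A small NumPy-only illustration of the cumulative sum and the split, using made-up offsets in the form $(\Delta x, \Delta y, q_2)$. The array values are hypothetical and only show how the split indices are derived.</p>
</div>
<div class='code'>

```python
import numpy as np

# Hypothetical strokes as (dx, dy, q2), where q2 = 1 marks the pen being lifted
deltas = np.array([
    [1., 0., 0.],
    [1., 1., 1.],
    [0., 2., 0.],
    [2., 0., 1.],
    [1., 1., 0.],
])

# Turn offsets into absolute coordinates
seq = deltas.copy()
seq[:, 0:2] = np.cumsum(deltas[:, 0:2], axis=0)

# Split after every row where the pen is lifted
split_points = np.where(seq[:, 2] > 0)[0] + 1
strokes = np.split(seq, split_points)
```

Here <code>split_points</code> is <code>[2, 4]</code>, so the five points are grouped into three segments of lengths 2, 2 and 1.
</div>
</div>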
<div class='section' id='section-109'>
<div class='docs'>
<div class='section-link'>
<a href='#section-109'>#</a>
</div>
<p>Plot each sequence of strokes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">450</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">strokes</span><span class="p">:</span>
<span class="lineno">451</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">s</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="n">s</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-110'>
<div class='docs'>
<div class='section-link'>
<a href='#section-110'>#</a>
</div>
<p>Don&rsquo;t show axes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">453</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">&#39;off&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-111'>
<div class='docs'>
<div class='section-link'>
<a href='#section-111'>#</a>
</div>
<p>Show the plot</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">455</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-112'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-112'>#</a>
</div>
<h2>Configurations</h2>
<p>These are default configurations which can later be adjusted by passing a <code>dict</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">458</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">TrainValidConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-113'>
<div class='docs'>
<div class='section-link'>
<a href='#section-113'>#</a>
</div>
<p>Device configurations to pick the device to run the experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">466</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-114'>
<div class='docs'>
<div class='section-link'>
<a href='#section-114'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">468</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">EncoderRNN</span>
<span class="lineno">469</span> <span class="n">decoder</span><span class="p">:</span> <span class="n">DecoderRNN</span>
<span class="lineno">470</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">optim</span><span class="o">.</span><span class="n">Adam</span>
<span class="lineno">471</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span>
<span class="lineno">472</span>
<span class="lineno">473</span> <span class="n">dataset_name</span><span class="p">:</span> <span class="nb">str</span>
<span class="lineno">474</span> <span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span>
<span class="lineno">475</span> <span class="n">valid_loader</span><span class="p">:</span> <span class="n">DataLoader</span>
<span class="lineno">476</span> <span class="n">train_dataset</span><span class="p">:</span> <span class="n">StrokesDataset</span>
<span class="lineno">477</span> <span class="n">valid_dataset</span><span class="p">:</span> <span class="n">StrokesDataset</span></pre></div>
</div>
</div>
<div class='section' id='section-115'>
<div class='docs'>
<div class='section-link'>
<a href='#section-115'>#</a>
</div>
<p>Encoder and decoder sizes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">480</span> <span class="n">enc_hidden_size</span> <span class="o">=</span> <span class="mi">256</span>
<span class="lineno">481</span> <span class="n">dec_hidden_size</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
</div>
</div>
<div class='section' id='section-116'>
<div class='docs'>
<div class='section-link'>
<a href='#section-116'>#</a>
</div>
<p>Batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">484</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">100</span></pre></div>
</div>
</div>
<div class='section' id='section-117'>
<div class='docs'>
<div class='section-link'>
<a href='#section-117'>#</a>
</div>
<p>Number of features in $z$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">487</span> <span class="n">d_z</span> <span class="o">=</span> <span class="mi">128</span></pre></div>
</div>
</div>
<div class='section' id='section-118'>
<div class='docs'>
<div class='section-link'>
<a href='#section-118'>#</a>
</div>
<p>Number of distributions in the mixture, $M$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">489</span> <span class="n">n_distributions</span> <span class="o">=</span> <span class="mi">20</span></pre></div>
</div>
</div>
<div class='section' id='section-119'>
<div class='docs'>
<div class='section-link'>
<a href='#section-119'>#</a>
</div>
<p>Weight of KL divergence loss, $w_{KL}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">492</span> <span class="n">kl_div_loss_weight</span> <span class="o">=</span> <span class="mf">0.5</span></pre></div>
</div>
</div>
<div class='section' id='section-120'>
<div class='docs'>
<div class='section-link'>
<a href='#section-120'>#</a>
</div>
<p>Gradient clipping</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">494</span> <span class="n">grad_clip</span> <span class="o">=</span> <span class="mf">1.</span></pre></div>
</div>
</div>
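<div class='section'>
<div class='docs'>
<p>This part of the page does not show how <code>grad_clip</code> is applied. One common approach, assumed here, is to clip the total gradient norm with <code>torch.nn.utils.clip_grad_norm_</code>. The small linear model below is a stand-in, not the encoder or decoder.</p>
</div>
<div class='code'>

```python
import torch
from torch import nn

torch.manual_seed(0)
grad_clip = 1.

# Hypothetical small model standing in for the encoder or decoder
model = nn.Linear(10, 10)
loss = (model(torch.randn(32, 10)) * 100).pow(2).mean()
loss.backward()

# Rescale gradients in place so their total norm is at most grad_clip;
# the return value is the total norm before clipping
norm_before = nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

# Total norm after clipping, computed from the per-parameter norms
norm_after = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
```

</div>
</div>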
<div class='section' id='section-121'>
<div class='docs'>
<div class='section-link'>
<a href='#section-121'>#</a>
</div>
<p>Temperature $\tau$ for sampling</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">496</span> <span class="n">temperature</span> <span class="o">=</span> <span class="mf">0.4</span></pre></div>
</div>
</div>
<div class='section' id='section-122'>
<div class='docs'>
<div class='section-link'>
<a href='#section-122'>#</a>
</div>
<p>Filter out stroke sequences longer than $200$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">499</span> <span class="n">max_seq_length</span> <span class="o">=</span> <span class="mi">200</span>
<span class="lineno">500</span>
<span class="lineno">501</span> <span class="n">epochs</span> <span class="o">=</span> <span class="mi">100</span>
<span class="lineno">502</span>
<span class="lineno">503</span> <span class="n">kl_div_loss</span> <span class="o">=</span> <span class="n">KLDivLoss</span><span class="p">()</span>
<span class="lineno">504</span> <span class="n">reconstruction_loss</span> <span class="o">=</span> <span class="n">ReconstructionLoss</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-123'>
<div class='docs'>
<div class='section-link'>
<a href='#section-123'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">506</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-124'>
<div class='docs'>
<div class='section-link'>
<a href='#section-124'>#</a>
</div>
<p>Initialize encoder &amp; decoder</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">508</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">EncoderRNN</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_z</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">enc_hidden_size</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">509</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">DecoderRNN</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_z</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dec_hidden_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_distributions</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-125'>
<div class='docs'>
<div class='section-link'>
<a href='#section-125'>#</a>
</div>
<p>Set the optimizer. The optimizer type and the learning rate are configurable</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">512</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">OptimizerConfigs</span><span class="p">()</span>
<span class="lineno">513</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">parameters</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
<span class="lineno">514</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span></pre></div>
</div>
</div>
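<div class='section'>
<div class='docs'>
<p>Without the configuration helper, a single optimizer over both networks could look roughly like the sketch below. The two <code>nn.LSTM</code> modules are stand-ins for the encoder and decoder, and the learning rate is an illustrative value, not one taken from this page.</p>
</div>
<div class='code'>

```python
import torch
from torch import nn, optim

# Hypothetical stand-ins for the encoder and decoder
encoder = nn.LSTM(5, 16)
decoder = nn.LSTM(5 + 8, 16)

# One optimizer over both parameter lists; the learning rate is illustrative
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = optim.Adam(params, lr=1e-3)

# Count the parameters the optimizer will update
n_optimized = sum(p.numel() for group in optimizer.param_groups for p in group['params'])
n_expected = sum(p.numel() for p in encoder.parameters()) + sum(p.numel() for p in decoder.parameters())
```

</div>
</div>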
<div class='section' id='section-126'>
<div class='docs'>
<div class='section-link'>
<a href='#section-126'>#</a>
</div>
<p>Create sampler</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">517</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">Sampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-127'>
<div class='docs'>
<div class='section-link'>
<a href='#section-127'>#</a>
</div>
<p><code>npz</code> file path is <code>data/sketch/[DATASET NAME].npz</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">520</span> <span class="n">path</span> <span class="o">=</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;sketch&#39;</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_name</span><span class="si">}</span><span class="s1">.npz&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-128'>
<div class='docs'>
<div class='section-link'>
<a href='#section-128'>#</a>
</div>
<p>Load the numpy file</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">522</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">path</span><span class="p">),</span> <span class="n">encoding</span><span class="o">=</span><span class="s1">&#39;latin1&#39;</span><span class="p">,</span> <span class="n">allow_pickle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-129'>
<div class='docs'>
<div class='section-link'>
<a href='#section-129'>#</a>
</div>
<p>Create training dataset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">525</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_dataset</span> <span class="o">=</span> <span class="n">StrokesDataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="s1">&#39;train&#39;</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-130'>
<div class='docs'>
<div class='section-link'>
<a href='#section-130'>#</a>
</div>
<p>Create the validation dataset, normalized with the scale computed from the training dataset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">527</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_dataset</span> <span class="o">=</span> <span class="n">StrokesDataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="s1">&#39;valid&#39;</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_dataset</span><span class="o">.</span><span class="n">scale</span><span class="p">)</span></pre></div>
</div>
</div>
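<div class='section'>
<div class='docs'>
<p>Passing the training scale to the validation dataset keeps both splits in the same units.
The following is a minimal sketch of that idea, assuming the scale is the standard deviation of the $(\Delta x, \Delta y)$ offsets.
The helpers <code>compute_scale</code> and <code>normalize</code> and the toy data are hypothetical and are not part of this implementation.</p>
</div>
<div class='code'>
<pre>
```python
import numpy as np


def compute_scale(sequences):
    """Standard deviation of all (dx, dy) offsets; an assumed definition of the scale."""
    offsets = np.concatenate([seq[:, 0:2] for seq in sequences])
    return float(np.std(offsets))


def normalize(sequences, scale):
    """Divide the (dx, dy) columns by `scale`, leaving the pen-state column untouched."""
    result = []
    for seq in sequences:
        seq = seq.astype(np.float64).copy()
        seq[:, 0:2] /= scale
        result.append(seq)
    return result


rng = np.random.default_rng(0)
# Toy sequences of rows (dx, dy, pen_state); the values are made up
train = [rng.normal(0.0, 20.0, size=(30, 3)) for _ in range(8)]
valid = [rng.normal(0.0, 20.0, size=(30, 3)) for _ in range(4)]

train_scale = compute_scale(train)
train_normalized = normalize(train, train_scale)
# Reuse the training scale instead of recomputing it on the validation split
valid_normalized = normalize(valid, train_scale)
```
</pre>
</div>
</div>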
<div class='section' id='section-131'>
<div class='docs'>
<div class='section-link'>
<a href='#section-131'>#</a>
</div>
<p>Create training data loader</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">530</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_dataset</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-132'>
<div class='docs'>
<div class='section-link'>
<a href='#section-132'>#</a>
</div>
<p>Create validation data loader</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">532</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_dataset</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-133'>
<div class='docs'>
<div class='section-link'>
<a href='#section-133'>#</a>
</div>
<p>Add hooks to monitor layer outputs on TensorBoard</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">535</span> <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="s1">&#39;encoder&#39;</span><span class="p">)</span>
<span class="lineno">536</span> <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">,</span> <span class="s1">&#39;decoder&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-134'>
<div class='docs'>
<div class='section-link'>
<a href='#section-134'>#</a>
</div>
<p>Configure the tracker to print the total train/validation loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">539</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">&quot;loss.total.*&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">540</span>
<span class="lineno">541</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_modules</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-135'>
<div class='docs'>
<div class='section-link'>
<a href='#section-135'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">543</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span>
<span class="lineno">544</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span>
<span class="lineno">545</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-136'>
<div class='docs'>
<div class='section-link'>
<a href='#section-136'>#</a>
</div>
<p>Move <code>data</code> and <code>mask</code> to device and swap the sequence and batch dimensions.
<code>data</code> will have shape <code>[seq_len, batch_size, 5]</code> and
<code>mask</code> will have shape <code>[seq_len, batch_size]</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">550</span> <span class="n">data</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">551</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-137'>
<div class='docs'>
<div class='section-link'>
<a href='#section-137'>#</a>
</div>
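<p>Increment the global step in training mode.
Since <code>data</code> has already been transposed, <code>len(data)</code> is the sequence length, not the batch size.</p>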
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">554</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
<span class="lineno">555</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-138'>
<div class='docs'>
<div class='section-link'>
<a href='#section-138'>#</a>
</div>
<p>Encode the sequence of strokes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">558</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">&quot;encoder&quot;</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-139'>
<div class='docs'>
<div class='section-link'>
<a href='#section-139'>#</a>
</div>
<p>Get $z$, $\mu$, and $\hat{\sigma}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">560</span> <span class="n">z</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">sigma_hat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-140'>
<div class='docs'>
<div class='section-link'>
<a href='#section-140'>#</a>
</div>
<p>Decode the mixture of distributions and $\hat{q}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">563</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">&quot;decoder&quot;</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-141'>
<div class='docs'>
<div class='section-link'>
<a href='#section-141'>#</a>
</div>
<p>Concatenate $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">565</span> <span class="n">z_stack</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">566</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">data</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">z_stack</span><span class="p">],</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
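<div class='section'>
<div class='docs'>
<p>The shapes involved in the dimension swap and the concatenation with $z$ can be checked on toy tensors.
This is only a sketch: the sizes <code>seq_len</code>, <code>batch_size</code> and <code>d_z</code> are arbitrary,
and a random tensor stands in for the encoder output.</p>
</div>
<div class='code'>
<pre>
```python
import torch

seq_len, batch_size, d_z = 6, 3, 4

# Toy batch in the loader's layout: [batch_size, seq_len, 5]
batch_data = torch.randn(batch_size, seq_len, 5)
# Swap to [seq_len, batch_size, 5]
data = batch_data.transpose(0, 1)

# Stand-in for the encoder output, shape [batch_size, d_z]
z = torch.randn(batch_size, d_z)

# Repeat z along a new leading time axis without copying memory
z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
# Every step except the last is an input; z is appended to each step
inputs = torch.cat([data[:-1], z_stack], 2)

print(data.shape)     # torch.Size([6, 3, 5])
print(z_stack.shape)  # torch.Size([5, 3, 4])
print(inputs.shape)   # torch.Size([5, 3, 9])
```
</pre>
</div>
</div>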
<div class='section' id='section-142'>
<div class='docs'>
<div class='section-link'>
<a href='#section-142'>#</a>
</div>
<p>Get mixture of distributions and $\hat{q}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">568</span> <span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-143'>
<div class='docs'>
<div class='section-link'>
<a href='#section-143'>#</a>
</div>
<p>Compute the loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">571</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;loss&#39;</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-144'>
<div class='docs'>
<div class='section-link'>
<a href='#section-144'>#</a>
</div>
<p>$L_{KL}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">573</span> <span class="n">kl_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_div_loss</span><span class="p">(</span><span class="n">sigma_hat</span><span class="p">,</span> <span class="n">mu</span><span class="p">)</span></pre></div>
</div>
</div>
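<div class='section'>
<div class='docs'>
<p>For intuition, here is a sketch of the closed-form KL term, assuming $\hat{\sigma}$ is the log-variance
so that $L_{KL} = -\frac{1}{2} \operatorname{mean}\big(1 + \hat{\sigma} - \mu^2 - \exp(\hat{\sigma})\big)$.
The function <code>kl_div_sketch</code> is an illustrative name and is not the <code>kl_div_loss</code> module used above.</p>
</div>
<div class='code'>
<pre>
```python
import torch


def kl_div_sketch(sigma_hat, mu):
    """KL divergence to a standard normal, averaged over all elements."""
    return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))


# With mu = 0 and sigma_hat = 0 (so sigma = 1) the posterior equals the prior
zero_kl = kl_div_sketch(torch.zeros(3, 4), torch.zeros(3, 4))
# Any other value gives a positive divergence
some_kl = kl_div_sketch(torch.full((3, 4), 0.5), torch.ones(3, 4))
```
</pre>
</div>
</div>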
<div class='section' id='section-145'>
<div class='docs'>
<div class='section-link'>
<a href='#section-145'>#</a>
</div>
<p>$L_R$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">575</span> <span class="n">reconstruction_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reconstruction_loss</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">:],</span> <span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-146'>
<div class='docs'>
<div class='section-link'>
<a href='#section-146'>#</a>
</div>
<p>$Loss = L_R + w_{KL} L_{KL}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">577</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">reconstruction_loss</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_div_loss_weight</span> <span class="o">*</span> <span class="n">kl_loss</span></pre></div>
</div>
</div>
<div class='section' id='section-147'>
<div class='docs'>
<div class='section-link'>
<a href='#section-147'>#</a>
</div>
<p>Track losses</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">580</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.kl.&quot;</span><span class="p">,</span> <span class="n">kl_loss</span><span class="p">)</span>
<span class="lineno">581</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.reconstruction.&quot;</span><span class="p">,</span> <span class="n">reconstruction_loss</span><span class="p">)</span>
<span class="lineno">582</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.total.&quot;</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-148'>
<div class='docs'>
<div class='section-link'>
<a href='#section-148'>#</a>
</div>
<p>Run the optimization steps only in training mode</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">585</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-149'>
<div class='docs'>
<div class='section-link'>
<a href='#section-149'>#</a>
</div>
<p>Run optimizer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">587</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;optimize&#39;</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-150'>
<div class='docs'>
<div class='section-link'>
<a href='#section-150'>#</a>
</div>
<p>Reset the gradients (<code>grad</code>) to zero</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">589</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-151'>
<div class='docs'>
<div class='section-link'>
<a href='#section-151'>#</a>
</div>
<p>Compute gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">591</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-152'>
<div class='docs'>
<div class='section-link'>
<a href='#section-152'>#</a>
</div>
<p>Log model parameters and gradients on the last batch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">593</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">:</span>
<span class="lineno">594</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">encoder</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="n">decoder</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-153'>
<div class='docs'>
<div class='section-link'>
<a href='#section-153'>#</a>
</div>
<p>Clip the gradient norms of the encoder and the decoder separately</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">596</span> <span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_clip</span><span class="p">)</span>
<span class="lineno">597</span> <span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_clip</span><span class="p">)</span></pre></div>
</div>
</div>
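<div class='section'>
<div class='docs'>
<p><code>clip_grad_norm_</code> rescales the gradients passed to it so that their combined norm is at most the given value.
Because the encoder and decoder are clipped in separate calls, each gets its own bound.
A small sketch of the effect on a stand-in <code>nn.Linear</code> model (not the Sketch RNN modules), with an arbitrary <code>max_norm</code>:</p>
</div>
<div class='code'>
<pre>
```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 2)

# A deliberately large loss so the gradients are large too
loss = (model(torch.randn(8, 4)) * 100.0).pow(2).sum()
loss.backward()

max_norm = 1.0
# Returns the total gradient norm measured before clipping
norm_before = nn.utils.clip_grad_norm_(model.parameters(), max_norm)

# Recompute the total norm from the clipped gradients
norm_after = torch.norm(torch.stack([p.grad.norm(2) for p in model.parameters()]), 2)
```
</pre>
</div>
</div>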
<div class='section' id='section-154'>
<div class='docs'>
<div class='section-link'>
<a href='#section-154'>#</a>
</div>
<p>Optimize</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">599</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="lineno">600</span>
<span class="lineno">601</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-155'>
<div class='docs'>
<div class='section-link'>
<a href='#section-155'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">603</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-156'>
<div class='docs'>
<div class='section-link'>
<a href='#section-156'>#</a>
</div>
<p>Randomly pick a sample from the validation dataset to feed to the encoder</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">605</span> <span class="n">data</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_dataset</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_dataset</span><span class="p">))]</span></pre></div>
</div>
</div>
<div class='section' id='section-157'>
<div class='docs'>
<div class='section-link'>
<a href='#section-157'>#</a>
</div>
<p>Add batch dimension and move it to device</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">607</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-158'>
<div class='docs'>
<div class='section-link'>
<a href='#section-158'>#</a>
</div>
<p>Sample</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">609</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-159'>
<div class='docs'>
<div class='section-link'>
<a href='#section-159'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">612</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span>
<span class="lineno">613</span> <span class="n">configs</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span>
<span class="lineno">614</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;sketch_rnn&quot;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-160'>
<div class='docs'>
<div class='section-link'>
<a href='#section-160'>#</a>
</div>
<p>Pass a dictionary of configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">617</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">configs</span><span class="p">,</span> <span class="p">{</span>
<span class="lineno">618</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-161'>
<div class='docs'>
<div class='section-link'>
<a href='#section-161'>#</a>
</div>
<p>We use a learning rate of <code>1e-3</code> to see results faster.
The paper suggests <code>1e-4</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">621</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">1e-3</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-162'>
<div class='docs'>
<div class='section-link'>
<a href='#section-162'>#</a>
</div>
<p>Name of the dataset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">623</span> <span class="s1">&#39;dataset_name&#39;</span><span class="p">:</span> <span class="s1">&#39;bicycle&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-163'>
<div class='docs'>
<div class='section-link'>
<a href='#section-163'>#</a>
</div>
<p>Number of inner iterations within an epoch, used to switch between training, validation, and sampling.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">625</span> <span class="s1">&#39;inner_iterations&#39;</span><span class="p">:</span> <span class="mi">10</span>
<span class="lineno">626</span> <span class="p">})</span>
<span class="lineno">627</span>
<span class="lineno">628</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-164'>
<div class='docs'>
<div class='section-link'>
<a href='#section-164'>#</a>
</div>
<p>Run the experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">630</span> <span class="n">configs</span><span class="o">.</span><span class="n">run</span><span class="p">()</span>
<span class="lineno">631</span>
<span class="lineno">632</span>
<span class="lineno">633</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
<span class="lineno">634</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>