mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 09:31:42 +08:00
2143 lines
147 KiB
HTML
2143 lines
147 KiB
HTML
<!DOCTYPE html>
|
|
<html lang="en">
|
|
<head>
|
|
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
|
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
|
<meta name="description" content="This is an annotated PyTorch implementation of the Sketch RNN from paper A Neural Representation of Sketch Drawings. Sketch RNN is a sequence-to-sequence model that generates sketches of objects such as bicycles, cats, etc."/>
|
|
|
|
<meta name="twitter:card" content="summary"/>
|
|
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
|
<meta name="twitter:title" content="Sketch RNN"/>
|
|
<meta name="twitter:description" content="This is an annotated PyTorch implementation of the Sketch RNN from paper A Neural Representation of Sketch Drawings. Sketch RNN is a sequence-to-sequence model that generates sketches of objects such as bicycles, cats, etc."/>
|
|
<meta name="twitter:site" content="@labmlai"/>
|
|
<meta name="twitter:creator" content="@labmlai"/>
|
|
|
|
<meta property="og:url" content="https://nn.labml.ai/sketch_rnn/index.html"/>
|
|
<meta property="og:title" content="Sketch RNN"/>
|
|
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
|
<meta property="og:site_name" content="LabML Neural Networks"/>
|
|
<meta property="og:type" content="object"/>
|
|
<meta property="og:title" content="Sketch RNN"/>
|
|
<meta property="og:description" content="This is an annotated PyTorch implementation of the Sketch RNN from paper A Neural Representation of Sketch Drawings. Sketch RNN is a sequence-to-sequence model that generates sketches of objects such as bicycles, cats, etc."/>
|
|
|
|
<title>Sketch RNN</title>
|
|
<link rel="shortcut icon" href="/icon.png"/>
|
|
<link rel="stylesheet" href="../pylit.css">
|
|
<link rel="canonical" href="https://nn.labml.ai/sketch_rnn/index.html"/>
|
|
<!-- Global site tag (gtag.js) - Google Analytics -->
|
|
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
|
<script>
|
|
// Google Analytics (gtag.js) bootstrap.
// Reuse window.dataLayer if it already exists; otherwise start it as an empty array.
window.dataLayer = window.dataLayer || [];
|
|
|
|
// gtag(...) appends its raw `arguments` object to dataLayer.
// NOTE(review): this matches Google's stock snippet, which pushes `arguments` itself
// rather than an array copy — confirm against the gtag.js docs before changing it.
function gtag() {
|
|
dataLayer.push(arguments);
|
|
}
|
|
|
|
// Issue the 'js' command with the current time.
gtag('js', new Date());
|
|
|
|
// Configure tracking for G-4V3HC8HBLH (the same ID used in the async gtag.js loader URL).
gtag('config', 'G-4V3HC8HBLH');
|
|
</script>
|
|
</head>
|
|
<body>
|
|
<div id='container'>
|
|
<div id="background"></div>
|
|
<div class='section'>
|
|
<div class='docs'>
|
|
<p>
|
|
<a class="parent" href="/">home</a>
|
|
<a class="parent" href="index.html">sketch_rnn</a>
|
|
</p>
|
|
<p>
|
|
|
|
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/sketch_rnn/__init__.py">
|
|
<img alt="Github"
|
|
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
|
|
style="max-width:100%;"/></a>
|
|
<a href="https://twitter.com/labmlai"
|
|
rel="nofollow">
|
|
<img alt="Twitter"
|
|
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
|
style="max-width:100%;"/></a>
|
|
</p>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-0'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-0'>#</a>
|
|
</div>
|
|
<h1>Sketch RNN</h1>
|
|
<p>This is an annotated <a href="https://pytorch.org">PyTorch</a> implementation of the paper
|
|
<a href="https://arxiv.org/abs/1704.03477">A Neural Representation of Sketch Drawings</a>.</p>
|
|
<p>Sketch RNN is a sequence-to-sequence variational auto-encoder.
|
|
Both encoder and decoder are recurrent neural network models.
|
|
It learns to reconstruct simple stroke-based drawings by predicting
|
|
a series of strokes.
|
|
The decoder predicts each stroke as a mixture of Gaussians.</p>
|
|
<h3>Getting data</h3>
|
|
<p>Download data from <a href="https://github.com/googlecreativelab/quickdraw-dataset">Quick, Draw! Dataset</a>.
|
|
There is a link to download <code>npz</code> files in the <em>Sketch-RNN QuickDraw Dataset</em> section of the README.
|
|
Place the downloaded <code>npz</code> file(s) in <code>data/sketch</code> folder.
|
|
This code is configured to use the <code>bicycle</code> dataset.
|
|
You can change this in the configurations.</p>
|
|
<h3>Acknowledgements</h3>
|
|
<p>This implementation took help from the <a href="https://github.com/alexis-jacq/Pytorch-Sketch-RNN">PyTorch Sketch RNN</a> project by
|
|
<a href="https://github.com/alexis-jacq">Alexis David Jacq</a></p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">32</span><span></span><span class="kn">import</span> <span class="nn">math</span>
|
|
<span class="lineno">33</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Any</span>
|
|
<span class="lineno">34</span>
|
|
<span class="lineno">35</span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
|
<span class="lineno">36</span><span class="kn">import</span> <span class="nn">torch</span>
|
|
<span class="lineno">37</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
|
<span class="lineno">38</span><span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span>
|
|
<span class="lineno">39</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">optim</span>
|
|
<span class="lineno">40</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span><span class="p">,</span> <span class="n">DataLoader</span>
|
|
<span class="lineno">41</span>
|
|
<span class="lineno">42</span><span class="kn">import</span> <span class="nn">einops</span>
|
|
<span class="lineno">43</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">lab</span><span class="p">,</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">monit</span>
|
|
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">labml_helpers.device</span> <span class="kn">import</span> <span class="n">DeviceConfigs</span>
|
|
<span class="lineno">45</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
|
<span class="lineno">46</span><span class="kn">from</span> <span class="nn">labml_helpers.optimizer</span> <span class="kn">import</span> <span class="n">OptimizerConfigs</span>
|
|
<span class="lineno">47</span><span class="kn">from</span> <span class="nn">labml_helpers.train_valid</span> <span class="kn">import</span> <span class="n">TrainValidConfigs</span><span class="p">,</span> <span class="n">hook_model_outputs</span><span class="p">,</span> <span class="n">BatchIndex</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-1'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-1'>#</a>
|
|
</div>
|
|
<h2>Dataset</h2>
|
|
<p>This class loads and pre-processes the data.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">50</span><span class="k">class</span> <span class="nc">StrokesDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-2'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-2'>#</a>
|
|
</div>
|
|
<p><code>dataset</code> is a list of numpy arrays of shape [seq_len, 3].
|
|
It is a sequence of strokes, and each stroke is represented by
|
|
3 integers.
|
|
First two are the displacements along x and y ($\Delta x$, $\Delta y$)
|
|
and the last integer represents the state of the pen: $0$ if the pen touches
|
|
the paper in the next step and $1$ if it is lifted off the paper.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">57</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">max_seq_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-3'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-3'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">67</span> <span class="n">data</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-4'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-4'>#</a>
|
|
</div>
|
|
<p>We iterate through each of the sequences and filter them by length.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">69</span> <span class="k">for</span> <span class="n">seq</span> <span class="ow">in</span> <span class="n">dataset</span><span class="p">:</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-5'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-5'>#</a>
|
|
</div>
|
|
<p>Keep the sequence only if its length is within our range: more than $10$ strokes and at most <code>max_seq_length</code>.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">71</span> <span class="k">if</span> <span class="mi">10</span> <span class="o"><</span> <span class="nb">len</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span> <span class="o"><=</span> <span class="n">max_seq_length</span><span class="p">:</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-6'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-6'>#</a>
|
|
</div>
|
|
<p>Clamp $\Delta x$, $\Delta y$ to $[-1000, 1000]$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">seq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span><span class="n">seq</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>
|
|
<span class="lineno">74</span> <span class="n">seq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">seq</span><span class="p">,</span> <span class="o">-</span><span class="mi">1000</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-7'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-7'>#</a>
|
|
</div>
|
|
<p>Convert to a floating point array and add to <code>data</code></p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">76</span> <span class="n">seq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">seq</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
|
<span class="lineno">77</span> <span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-8'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-8'>#</a>
|
|
</div>
|
|
<p>We then calculate the scaling factor which is the
|
|
standard deviation of ($\Delta x$, $\Delta y$) combined.
|
|
The paper notes that the mean is not adjusted for simplicity,
|
|
since the mean is anyway close to $0$.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">83</span> <span class="k">if</span> <span class="n">scale</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="lineno">84</span> <span class="n">scale</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">np</span><span class="o">.</span><span class="n">ravel</span><span class="p">(</span><span class="n">s</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">])</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">data</span><span class="p">]))</span>
|
|
<span class="lineno">85</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">scale</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-9'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-9'>#</a>
|
|
</div>
|
|
<p>Get the longest sequence length among all sequences</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">88</span> <span class="n">longest_seq_len</span> <span class="o">=</span> <span class="nb">max</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span> <span class="k">for</span> <span class="n">seq</span> <span class="ow">in</span> <span class="n">data</span><span class="p">])</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-10'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-10'>#</a>
|
|
</div>
|
|
<p>We initialize the PyTorch data array with two extra steps for start-of-sequence (sos)
|
|
and end-of-sequence (eos).
|
|
Each step is a vector $(\Delta x, \Delta y, p_1, p_2, p_3)$.
|
|
Only one of $p_1, p_2, p_3$ is $1$ and the others are $0$.
|
|
They represent <em>pen down</em>, <em>pen up</em> and <em>end-of-sequence</em> in that order.
|
|
$p_1$ is $1$ if the pen touches the paper in the next step.
|
|
$p_2$ is $1$ if the pen doesn’t touch the paper in the next step.
|
|
$p_3$ is $1$ if it is the end of the drawing.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">98</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="n">longest_seq_len</span> <span class="o">+</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-11'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-11'>#</a>
|
|
</div>
|
|
<p>The mask array needs only one extra step since it is for the outputs of the
|
|
decoder, which takes in <code>data[:-1]</code> and predicts the next step.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">101</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">),</span> <span class="n">longest_seq_len</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
|
|
<span class="lineno">102</span>
|
|
<span class="lineno">103</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">seq</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">data</span><span class="p">):</span>
|
|
<span class="lineno">104</span> <span class="n">seq</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span>
|
|
<span class="lineno">105</span> <span class="n">len_seq</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-12'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-12'>#</a>
|
|
</div>
|
|
<p>Scale and set $\Delta x, \Delta y$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">107</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span><span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">seq</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="n">scale</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-13'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-13'>#</a>
|
|
</div>
|
|
<p>$p_1$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span><span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-14'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-14'>#</a>
|
|
</div>
|
|
<p>$p_2$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">111</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span><span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-15'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-15'>#</a>
|
|
</div>
|
|
<p>$p_3$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">113</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-16'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-16'>#</a>
|
|
</div>
|
|
<p>Mask is on until end of sequence</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">115</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:</span><span class="n">len_seq</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-17'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-17'>#</a>
|
|
</div>
|
|
<p>Start-of-sequence is $(0, 0, 1, 0, 0)$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">118</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-18'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-18'>#</a>
|
|
</div>
|
|
<p>Size of the dataset</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">120</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-19'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-19'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">122</span> <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-20'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-20'>#</a>
|
|
</div>
|
|
<p>Get a sample</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">124</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-21'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-21'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">126</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">idx</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-22'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-22'>#</a>
|
|
</div>
|
|
<h2>Bi-variate Gaussian mixture</h2>
|
|
<p>The mixture is represented by $\Pi$ and
|
|
$\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$.
|
|
This class adjusts temperatures and creates the categorical and Gaussian
|
|
distributions from the parameters.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">129</span><span class="k">class</span> <span class="nc">BivariateGaussianMixture</span><span class="p">:</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-23'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-23'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">139</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pi_logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mu_x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mu_y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="lineno">140</span> <span class="n">sigma_x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sigma_y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">rho_xy</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="lineno">141</span> <span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span> <span class="o">=</span> <span class="n">pi_logits</span>
|
|
<span class="lineno">142</span> <span class="bp">self</span><span class="o">.</span><span class="n">mu_x</span> <span class="o">=</span> <span class="n">mu_x</span>
|
|
<span class="lineno">143</span> <span class="bp">self</span><span class="o">.</span><span class="n">mu_y</span> <span class="o">=</span> <span class="n">mu_y</span>
|
|
<span class="lineno">144</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma_x</span> <span class="o">=</span> <span class="n">sigma_x</span>
|
|
<span class="lineno">145</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma_y</span> <span class="o">=</span> <span class="n">sigma_y</span>
|
|
<span class="lineno">146</span> <span class="bp">self</span><span class="o">.</span><span class="n">rho_xy</span> <span class="o">=</span> <span class="n">rho_xy</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-24'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-24'>#</a>
|
|
</div>
|
|
<p>Number of distributions in the mixture, $M$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">148</span> <span class="nd">@property</span>
|
|
<span class="lineno">149</span> <span class="k">def</span> <span class="nf">n_distributions</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-25'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-25'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">151</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-26'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-26'>#</a>
|
|
</div>
|
|
<p>Adjust by temperature $\tau$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">153</span> <span class="k">def</span> <span class="nf">set_temperature</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">temperature</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-27'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-27'>#</a>
|
|
</div>
|
|
<p>
|
|
<script type="math/tex; mode=display">\hat{\Pi_k} \leftarrow \frac{\hat{\Pi_k}}{\tau}</script>
|
|
</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">158</span> <span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span> <span class="o">/=</span> <span class="n">temperature</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-28'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-28'>#</a>
|
|
</div>
|
|
<p>
|
|
<script type="math/tex; mode=display">\sigma^2_x \leftarrow \sigma^2_x \tau</script>
|
|
</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">160</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma_x</span> <span class="o">*=</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">temperature</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-29'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-29'>#</a>
|
|
</div>
|
|
<p>
|
|
<script type="math/tex; mode=display">\sigma^2_y \leftarrow \sigma^2_y \tau</script>
|
|
</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">162</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma_y</span> <span class="o">*=</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">temperature</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-30'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-30'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">164</span> <span class="k">def</span> <span class="nf">get_distribution</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-31'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-31'>#</a>
|
|
</div>
|
|
<p>Clamp $\sigma_x$, $\sigma_y$ and $\rho_{xy}$ to avoid getting <code>NaN</code>s</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">sigma_x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp_min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sigma_x</span><span class="p">,</span> <span class="mf">1e-5</span><span class="p">)</span>
|
|
<span class="lineno">167</span> <span class="n">sigma_y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp_min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sigma_y</span><span class="p">,</span> <span class="mf">1e-5</span><span class="p">)</span>
|
|
<span class="lineno">168</span> <span class="n">rho_xy</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">rho_xy</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span> <span class="o">+</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="mi">1</span> <span class="o">-</span> <span class="mf">1e-5</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-32'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-32'>#</a>
|
|
</div>
|
|
<p>Get means</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">171</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">mu_x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mu_y</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-33'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-33'>#</a>
|
|
</div>
|
|
<p>Get covariance matrix</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">cov</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span>
|
|
<span class="lineno">174</span> <span class="n">sigma_x</span> <span class="o">*</span> <span class="n">sigma_x</span><span class="p">,</span> <span class="n">rho_xy</span> <span class="o">*</span> <span class="n">sigma_x</span> <span class="o">*</span> <span class="n">sigma_y</span><span class="p">,</span>
|
|
<span class="lineno">175</span> <span class="n">rho_xy</span> <span class="o">*</span> <span class="n">sigma_x</span> <span class="o">*</span> <span class="n">sigma_y</span><span class="p">,</span> <span class="n">sigma_y</span> <span class="o">*</span> <span class="n">sigma_y</span>
|
|
<span class="lineno">176</span> <span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="lineno">177</span> <span class="n">cov</span> <span class="o">=</span> <span class="n">cov</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">sigma_y</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-34'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-34'>#</a>
|
|
</div>
|
|
<p>Create bi-variate normal distribution.</p>
|
|
<p>📝 It would be more efficient to use a <code>scale_tril</code> matrix of the form <code>[[a, 0], [b, c]]</code>
|
|
where
|
|
<script type="math/tex; mode=display">a = \sigma_x, b = \rho_{xy} \sigma_y, c = \sigma_y \sqrt{1 - \rho^2_{xy}}</script>.
|
|
But for simplicity we use the co-variance matrix.
|
|
<a href="https://www2.stat.duke.edu/courses/Spring12/sta104.1/Lectures/Lec22.pdf">This is a good resource</a>
|
|
if you want to read up more about bi-variate distributions, their co-variance matrix,
|
|
and probability density function.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">188</span> <span class="n">multi_dist</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">MultivariateNormal</span><span class="p">(</span><span class="n">mean</span><span class="p">,</span> <span class="n">covariance_matrix</span><span class="o">=</span><span class="n">cov</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-35'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-35'>#</a>
|
|
</div>
|
|
<p>Create categorical distribution $\Pi$ from logits</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">191</span> <span class="n">cat_dist</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">pi_logits</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-36'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-36'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">194</span> <span class="k">return</span> <span class="n">cat_dist</span><span class="p">,</span> <span class="n">multi_dist</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-37'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-37'>#</a>
|
|
</div>
|
|
<h2>Encoder module</h2>
|
|
<p>This consists of a bidirectional LSTM</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">197</span><span class="k">class</span> <span class="nc">EncoderRNN</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-38'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-38'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">204</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_z</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">enc_hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
|
<span class="lineno">205</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-39'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-39'>#</a>
|
|
</div>
|
|
<p>Create a bidirectional LSTM taking a sequence of
|
|
$(\Delta x, \Delta y, p_1, p_2, p_3)$ as input.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">208</span> <span class="bp">self</span><span class="o">.</span><span class="n">lstm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LSTM</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">enc_hidden_size</span><span class="p">,</span> <span class="n">bidirectional</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-40'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-40'>#</a>
|
|
</div>
|
|
<p>Head to get $\mu$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">210</span> <span class="bp">self</span><span class="o">.</span><span class="n">mu_head</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">enc_hidden_size</span><span class="p">,</span> <span class="n">d_z</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-41'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-41'>#</a>
|
|
</div>
|
|
<p>Head to get $\hat{\sigma}$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">212</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma_head</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">enc_hidden_size</span><span class="p">,</span> <span class="n">d_z</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-42'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-42'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">214</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">state</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-43'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-43'>#</a>
|
|
</div>
|
|
<p>The hidden state of the bidirectional LSTM is the concatenation of the
|
|
output of the last token in the forward direction and
|
|
the first token in the reverse direction, which is what we want.
|
|
<script type="math/tex; mode=display">h_{\rightarrow} = encode_{\rightarrow}(S),
|
|
h_{\leftarrow} = encode_{\leftarrow}(S_{reverse}),
|
|
h = [h_{\rightarrow}; h_{\leftarrow}]</script>
|
|
</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">221</span> <span class="n">_</span><span class="p">,</span> <span class="p">(</span><span class="n">hidden</span><span class="p">,</span> <span class="n">cell</span><span class="p">)</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lstm</span><span class="p">(</span><span class="n">inputs</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">state</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-44'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-44'>#</a>
|
|
</div>
|
|
<p>The state has shape <code>[2, batch_size, hidden_size]</code>,
|
|
where the first dimension is the direction.
|
|
We rearrange it to get $h = [h_{\rightarrow}; h_{\leftarrow}]$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">225</span> <span class="n">hidden</span> <span class="o">=</span> <span class="n">einops</span><span class="o">.</span><span class="n">rearrange</span><span class="p">(</span><span class="n">hidden</span><span class="p">,</span> <span class="s1">'fb b h -> b (fb h)'</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-45'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-45'>#</a>
|
|
</div>
|
|
<p>$\mu$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">mu</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mu_head</span><span class="p">(</span><span class="n">hidden</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-46'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-46'>#</a>
|
|
</div>
|
|
<p>$\hat{\sigma}$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">230</span> <span class="n">sigma_hat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma_head</span><span class="p">(</span><span class="n">hidden</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-47'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-47'>#</a>
|
|
</div>
|
|
<p>$\sigma = \exp(\frac{\hat{\sigma}}{2})$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">232</span> <span class="n">sigma</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">sigma_hat</span> <span class="o">/</span> <span class="mf">2.</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-48'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-48'>#</a>
|
|
</div>
|
|
<p>Sample $z = \mu + \sigma \cdot \mathcal{N}(0, I)$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">235</span> <span class="n">z</span> <span class="o">=</span> <span class="n">mu</span> <span class="o">+</span> <span class="n">sigma</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">mu</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">mu</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="n">mu</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">mu</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-49'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-49'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">238</span> <span class="k">return</span> <span class="n">z</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">sigma_hat</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-50'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-50'>#</a>
|
|
</div>
|
|
<h2>Decoder module</h2>
|
|
<p>This consists of an LSTM</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">241</span><span class="k">class</span> <span class="nc">DecoderRNN</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-51'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-51'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">248</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_z</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dec_hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_distributions</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
|
<span class="lineno">249</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-52'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-52'>#</a>
|
|
</div>
|
|
<p>LSTM takes $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ as input</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">251</span> <span class="bp">self</span><span class="o">.</span><span class="n">lstm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LSTM</span><span class="p">(</span><span class="n">d_z</span> <span class="o">+</span> <span class="mi">5</span><span class="p">,</span> <span class="n">dec_hidden_size</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-53'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-53'>#</a>
|
|
</div>
|
|
<p>Initial state of the LSTM is $[h_0; c_0] = \tanh(W_{z}z + b_z)$.
|
|
<code>init_state</code> is the linear transformation for this</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">255</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_state</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_z</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">dec_hidden_size</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-54'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-54'>#</a>
|
|
</div>
|
|
<p>This layer produces outputs for each of the <code>n_distributions</code>.
|
|
Each distribution needs six parameters
|
|
$(\hat{\Pi_i}, \mu_{x,i}, \mu_{y,i}, \hat{\sigma_{x,i}}, \hat{\sigma_{y,i}}, \hat{\rho_{xy,i}})$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">260</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixtures</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dec_hidden_size</span><span class="p">,</span> <span class="mi">6</span> <span class="o">*</span> <span class="n">n_distributions</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-55'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-55'>#</a>
|
|
</div>
|
|
<p>This head is for the logits $(\hat{q_1}, \hat{q_2}, \hat{q_3})$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">263</span> <span class="bp">self</span><span class="o">.</span><span class="n">q_head</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dec_hidden_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-56'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-56'>#</a>
|
|
</div>
|
|
<p>This is to calculate $\log(q_k)$ where
|
|
<script type="math/tex; mode=display">q_k = \operatorname{softmax}(\hat{q})_k = \frac{\exp(\hat{q_k})}{\sum_{j = 1}^3 \exp(\hat{q_j})}</script>
|
|
</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">266</span> <span class="bp">self</span><span class="o">.</span><span class="n">q_log_softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LogSoftmax</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-57'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-57'>#</a>
|
|
</div>
|
|
<p>These parameters are stored for future reference</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">269</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_distributions</span> <span class="o">=</span> <span class="n">n_distributions</span>
|
|
<span class="lineno">270</span> <span class="bp">self</span><span class="o">.</span><span class="n">dec_hidden_size</span> <span class="o">=</span> <span class="n">dec_hidden_size</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-58'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-58'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">272</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">z</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-59'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-59'>#</a>
|
|
</div>
|
|
<p>Calculate the initial state</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">274</span> <span class="k">if</span> <span class="n">state</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-60'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-60'>#</a>
|
|
</div>
|
|
<p>$[h_0; c_0] = \tanh(W_{z}z + b_z)$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">276</span> <span class="n">h</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">init_state</span><span class="p">(</span><span class="n">z</span><span class="p">)),</span> <span class="bp">self</span><span class="o">.</span><span class="n">dec_hidden_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-61'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-61'>#</a>
|
|
</div>
|
|
<p><code>h</code> and <code>c</code> have shapes <code>[batch_size, lstm_size]</code>. We want to shape them
|
|
to <code>[1, batch_size, lstm_size]</code> because that is the state shape the LSTM expects.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">279</span> <span class="n">state</span> <span class="o">=</span> <span class="p">(</span><span class="n">h</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">(),</span> <span class="n">c</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">())</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-62'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-62'>#</a>
|
|
</div>
|
|
<p>Run the LSTM</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">282</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lstm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">state</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-63'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-63'>#</a>
|
|
</div>
|
|
<p>Get $\log(q)$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">285</span> <span class="n">q_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">q_log_softmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">q_head</span><span class="p">(</span><span class="n">outputs</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-64'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-64'>#</a>
|
|
</div>
|
|
<p>Get $(\hat{\Pi_i}, \mu_{x,i}, \mu_{y,i}, \hat{\sigma_{x,i}},
|
|
\hat{\sigma_{y,i}}, \hat{\rho_{xy,i}})$.
|
|
<code>torch.split</code> splits the output into 6 tensors of size <code>self.n_distributions</code>
|
|
across dimension <code>2</code>.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">291</span> <span class="n">pi_logits</span><span class="p">,</span> <span class="n">mu_x</span><span class="p">,</span> <span class="n">mu_y</span><span class="p">,</span> <span class="n">sigma_x</span><span class="p">,</span> <span class="n">sigma_y</span><span class="p">,</span> <span class="n">rho_xy</span> <span class="o">=</span> \
|
|
<span class="lineno">292</span> <span class="n">torch</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mixtures</span><span class="p">(</span><span class="n">outputs</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_distributions</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-65'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-65'>#</a>
|
|
</div>
|
|
<p>Create a bi-variate Gaussian mixture
|
|
$\Pi$ and
|
|
$\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$
|
|
where
|
|
<script type="math/tex; mode=display">\sigma_{x,i} = \exp(\hat{\sigma_{x,i}}), \sigma_{y,i} = \exp(\hat{\sigma_{y,i}}),
|
|
\rho_{xy,i} = \tanh(\hat{\rho_{xy,i}})</script>
|
|
and
|
|
<script type="math/tex; mode=display">\Pi_i = \operatorname{softmax}(\hat{\Pi})_i = \frac{\exp(\hat{\Pi_i})}{\sum_{j = 1}^3 \exp(\hat{\Pi_j})}</script>
|
|
</p>
|
|
<p>$\Pi$ gives the categorical probabilities of choosing each of the $M$ (<code>n_distributions</code>) distributions in the mixture
|
|
$\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">305</span> <span class="n">dist</span> <span class="o">=</span> <span class="n">BivariateGaussianMixture</span><span class="p">(</span><span class="n">pi_logits</span><span class="p">,</span> <span class="n">mu_x</span><span class="p">,</span> <span class="n">mu_y</span><span class="p">,</span>
|
|
<span class="lineno">306</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">sigma_x</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">sigma_y</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">rho_xy</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-66'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-66'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">309</span> <span class="k">return</span> <span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">,</span> <span class="n">state</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-67'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-67'>#</a>
|
|
</div>
|
|
<h2>Reconstruction Loss</h2>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">312</span><span class="k">class</span> <span class="nc">ReconstructionLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-68'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-68'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">317</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="lineno">318</span> <span class="n">dist</span><span class="p">:</span> <span class="s1">'BivariateGaussianMixture'</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-69'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-69'>#</a>
|
|
</div>
|
|
<p>Get $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">320</span> <span class="n">pi</span><span class="p">,</span> <span class="n">mix</span> <span class="o">=</span> <span class="n">dist</span><span class="o">.</span><span class="n">get_distribution</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-70'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-70'>#</a>
|
|
</div>
|
|
<p><code>target</code> has shape <code>[seq_len, batch_size, 5]</code> where the last dimension is the features
|
|
$(\Delta x, \Delta y, p_1, p_2, p_3)$.
|
|
We want to get $\Delta x, \Delta y$ and get the probabilities from each of the distributions
|
|
in the mixture $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$.</p>
|
|
<p><code>xy</code> will have shape <code>[seq_len, batch_size, n_distributions, 2]</code></p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">327</span> <span class="n">xy</span> <span class="o">=</span> <span class="n">target</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">dist</span><span class="o">.</span><span class="n">n_distributions</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-71'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-71'>#</a>
|
|
</div>
|
|
<p>Calculate the probabilities
|
|
<script type="math/tex; mode=display">p(\Delta x, \Delta y) =
|
|
\sum_{j=1}^M \Pi_j \mathcal{N} \big( \Delta x, \Delta y \vert
|
|
\mu_{x,j}, \mu_{y,j}, \sigma_{x,j}, \sigma_{y,j}, \rho_{xy,j}
|
|
\big)</script>
|
|
</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">333</span> <span class="n">probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">pi</span><span class="o">.</span><span class="n">probs</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">mix</span><span class="o">.</span><span class="n">log_prob</span><span class="p">(</span><span class="n">xy</span><span class="p">)),</span> <span class="mi">2</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-72'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-72'>#</a>
|
|
</div>
|
|
<p>
|
|
<script type="math/tex; mode=display">L_s = - \frac{1}{N_{max}} \sum_{i=1}^{N_s} \log \big (p(\Delta x, \Delta y) \big)</script>
|
|
Although <code>probs</code> has $N_{max}$ (<code>longest_seq_len</code>) elements, the sum is only taken
|
|
up to $N_s$ because the rest is masked out.</p>
|
|
<p>It might feel like we should be taking the sum and dividing by $N_s$ and not $N_{max}$,
|
|
but this would give higher weight to individual predictions in shorter sequences.
|
|
We give equal weight to each prediction $p(\Delta x, \Delta y)$ when we divide by $N_{max}$.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">342</span> <span class="n">loss_stroke</span> <span class="o">=</span> <span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">mask</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mf">1e-5</span> <span class="o">+</span> <span class="n">probs</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-73'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-73'>#</a>
|
|
</div>
|
|
<p>
|
|
<script type="math/tex; mode=display">L_p = - \frac{1}{N_{max}} \sum_{i=1}^{N_{max}} \sum_{k=1}^{3} p_{k,i} \log(q_{k,i})</script>
|
|
</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">345</span> <span class="n">loss_pen</span> <span class="o">=</span> <span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">target</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">*</span> <span class="n">q_logits</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-74'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-74'>#</a>
|
|
</div>
|
|
<p>
|
|
<script type="math/tex; mode=display">L_R = L_s + L_p</script>
|
|
</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">348</span> <span class="k">return</span> <span class="n">loss_stroke</span> <span class="o">+</span> <span class="n">loss_pen</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-75'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-75'>#</a>
|
|
</div>
|
|
<h2>KL-Divergence loss</h2>
|
|
<p>This calculates the KL divergence between a given normal distribution and $\mathcal{N}(0, 1)$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">351</span><span class="k">class</span> <span class="nc">KLDivLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-76'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-76'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">358</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sigma_hat</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mu</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-77'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-77'>#</a>
|
|
</div>
|
|
<p>
|
|
<script type="math/tex; mode=display">L_{KL} = - \frac{1}{2 N_z} \bigg( 1 + \hat{\sigma} - \mu^2 - \exp(\hat{\sigma}) \bigg)</script>
|
|
</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">360</span> <span class="k">return</span> <span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">sigma_hat</span> <span class="o">-</span> <span class="n">mu</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">sigma_hat</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-78'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-78'>#</a>
|
|
</div>
|
|
<h2>Sampler</h2>
|
|
<p>This samples a sketch from the decoder and plots it</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">363</span><span class="k">class</span> <span class="nc">Sampler</span><span class="p">:</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-79'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-79'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">370</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">EncoderRNN</span><span class="p">,</span> <span class="n">decoder</span><span class="p">:</span> <span class="n">DecoderRNN</span><span class="p">):</span>
|
|
<span class="lineno">371</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">decoder</span>
|
|
<span class="lineno">372</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-80'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-80'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">374</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">temperature</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-81'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-81'>#</a>
|
|
</div>
|
|
<p>$N_{max}$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">376</span> <span class="n">longest_seq_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-82'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-82'>#</a>
|
|
</div>
|
|
<p>Get $z$ from the encoder</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">379</span> <span class="n">z</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-83'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-83'>#</a>
|
|
</div>
|
|
<p>Start-of-sequence stroke is $(0, 0, 1, 0, 0)$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">382</span> <span class="n">s</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span>
|
|
<span class="lineno">383</span> <span class="n">seq</span> <span class="o">=</span> <span class="p">[</span><span class="n">s</span><span class="p">]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-84'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-84'>#</a>
|
|
</div>
|
|
<p>Initial decoder state is <code>None</code>.
|
|
The decoder will initialize it to $[h_0; c_0] = \tanh(W_{z}z + b_z)$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">386</span> <span class="n">state</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-85'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-85'>#</a>
|
|
</div>
|
|
<p>We don’t need gradients</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">389</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-86'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-86'>#</a>
|
|
</div>
|
|
<p>Sample $N_{max}$ strokes</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">391</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">longest_seq_len</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-87'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-87'>#</a>
|
|
</div>
|
|
<p>$[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ is the input to the decoder</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">393</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">s</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">z</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)],</span> <span class="mi">2</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-88'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-88'>#</a>
|
|
</div>
|
|
<p>Get $\Pi$, $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$,
|
|
$q$ and the next state from the decoder</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">396</span> <span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="n">state</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-89'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-89'>#</a>
|
|
</div>
|
|
<p>Sample a stroke</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">398</span> <span class="n">s</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sample_step</span><span class="p">(</span><span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">,</span> <span class="n">temperature</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-90'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-90'>#</a>
|
|
</div>
|
|
<p>Add the new stroke to the sequence of strokes</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">400</span> <span class="n">seq</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">s</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-91'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-91'>#</a>
|
|
</div>
|
|
<p>Stop sampling if $p_3 = 1$. This indicates that sketching has stopped.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">402</span> <span class="k">if</span> <span class="n">s</span><span class="p">[</span><span class="mi">4</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="lineno">403</span> <span class="k">break</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-92'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-92'>#</a>
|
|
</div>
|
|
<p>Create a PyTorch tensor of the sequence of strokes</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">406</span> <span class="n">seq</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-93'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-93'>#</a>
|
|
</div>
|
|
<p>Plot the sequence of strokes</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">409</span> <span class="bp">self</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-94'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-94'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">411</span> <span class="nd">@staticmethod</span>
|
|
<span class="lineno">412</span> <span class="k">def</span> <span class="nf">_sample_step</span><span class="p">(</span><span class="n">dist</span><span class="p">:</span> <span class="s1">'BivariateGaussianMixture'</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">temperature</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-95'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-95'>#</a>
|
|
</div>
|
|
<p>Set temperature $\tau$ for sampling. This is implemented in class <code>BivariateGaussianMixture</code>.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">414</span> <span class="n">dist</span><span class="o">.</span><span class="n">set_temperature</span><span class="p">(</span><span class="n">temperature</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-96'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-96'>#</a>
|
|
</div>
|
|
<p>Get temperature adjusted $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">416</span> <span class="n">pi</span><span class="p">,</span> <span class="n">mix</span> <span class="o">=</span> <span class="n">dist</span><span class="o">.</span><span class="n">get_distribution</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-97'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-97'>#</a>
|
|
</div>
|
|
<p>Sample from $\Pi$ the index of the distribution to use from the mixture</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">418</span> <span class="n">idx</span> <span class="o">=</span> <span class="n">pi</span><span class="o">.</span><span class="n">sample</span><span class="p">()[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-98'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-98'>#</a>
|
|
</div>
|
|
<p>Create categorical distribution $q$ from the log-probabilities <code>q_logits</code> ($\hat{q}$), scaled by the temperature $\tau$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">421</span> <span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">q_logits</span> <span class="o">/</span> <span class="n">temperature</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-99'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-99'>#</a>
|
|
</div>
|
|
<p>Sample from $q$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">423</span> <span class="n">q_idx</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">sample</span><span class="p">()[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-100'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-100'>#</a>
|
|
</div>
|
|
<p>Sample from the normal distributions in the mixture and pick the one indexed by <code>idx</code></p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">426</span> <span class="n">xy</span> <span class="o">=</span> <span class="n">mix</span><span class="o">.</span><span class="n">sample</span><span class="p">()[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">idx</span><span class="p">]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-101'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-101'>#</a>
|
|
</div>
|
|
<p>Create an empty stroke $(\Delta x, \Delta y, q_1, q_2, q_3)$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">429</span> <span class="n">stroke</span> <span class="o">=</span> <span class="n">q_logits</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-102'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-102'>#</a>
|
|
</div>
|
|
<p>Set $\Delta x, \Delta y$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">431</span> <span class="n">stroke</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">xy</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-103'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-103'>#</a>
|
|
</div>
|
|
<p>Set $q_1, q_2, q_3$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">433</span> <span class="n">stroke</span><span class="p">[</span><span class="n">q_idx</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-104'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-104'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">435</span> <span class="k">return</span> <span class="n">stroke</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-105'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-105'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">437</span> <span class="nd">@staticmethod</span>
|
|
<span class="lineno">438</span> <span class="k">def</span> <span class="nf">plot</span><span class="p">(</span><span class="n">seq</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-106'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-106'>#</a>
|
|
</div>
|
|
<p>Take the cumulative sums of $(\Delta x, \Delta y)$ to get $(x, y)$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">440</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">seq</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-107'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-107'>#</a>
|
|
</div>
|
|
<p>Create a new numpy array of the form $(x, y, q_2)$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">442</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span>
|
|
<span class="lineno">443</span> <span class="n">seq</span> <span class="o">=</span> <span class="n">seq</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">3</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-108'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-108'>#</a>
|
|
</div>
|
|
<p>Split the array at points where $q_2$ is $1$.
|
|
That is, split the array of strokes at the points where the pen is lifted from the paper.
|
|
This gives a list of sequences of strokes.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">448</span> <span class="n">strokes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">seq</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">seq</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">></span> <span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-109'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-109'>#</a>
|
|
</div>
|
|
<p>Plot each sequence of strokes</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">450</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">strokes</span><span class="p">:</span>
|
|
<span class="lineno">451</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">s</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="n">s</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-110'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-110'>#</a>
|
|
</div>
|
|
<p>Don’t show axes</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">453</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">'off'</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-111'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-111'>#</a>
|
|
</div>
|
|
<p>Show the plot</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">455</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-112'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-112'>#</a>
|
|
</div>
|
|
<h2>Configurations</h2>
|
|
<p>These are default configurations which can later be adjusted by passing a <code>dict</code>.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">458</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">TrainValidConfigs</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-113'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-113'>#</a>
|
|
</div>
|
|
<p>Device configurations to pick the device to run the experiment</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">466</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-114'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-114'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">468</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">EncoderRNN</span>
|
|
<span class="lineno">469</span> <span class="n">decoder</span><span class="p">:</span> <span class="n">DecoderRNN</span>
|
|
<span class="lineno">470</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">optim</span><span class="o">.</span><span class="n">Adam</span>
|
|
<span class="lineno">471</span> <span class="n">sampler</span><span class="p">:</span> <span class="n">Sampler</span>
|
|
<span class="lineno">472</span>
|
|
<span class="lineno">473</span> <span class="n">dataset_name</span><span class="p">:</span> <span class="nb">str</span>
|
|
<span class="lineno">474</span> <span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span>
|
|
<span class="lineno">475</span> <span class="n">valid_loader</span><span class="p">:</span> <span class="n">DataLoader</span>
|
|
<span class="lineno">476</span> <span class="n">train_dataset</span><span class="p">:</span> <span class="n">StrokesDataset</span>
|
|
<span class="lineno">477</span> <span class="n">valid_dataset</span><span class="p">:</span> <span class="n">StrokesDataset</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-115'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-115'>#</a>
|
|
</div>
|
|
<p>Encoder and decoder sizes</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">480</span> <span class="n">enc_hidden_size</span> <span class="o">=</span> <span class="mi">256</span>
|
|
<span class="lineno">481</span> <span class="n">dec_hidden_size</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-116'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-116'>#</a>
|
|
</div>
|
|
<p>Batch size</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">484</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">100</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-117'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-117'>#</a>
|
|
</div>
|
|
<p>Number of features in $z$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">487</span> <span class="n">d_z</span> <span class="o">=</span> <span class="mi">128</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-118'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-118'>#</a>
|
|
</div>
|
|
<p>Number of distributions in the mixture, $M$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">489</span> <span class="n">n_distributions</span> <span class="o">=</span> <span class="mi">20</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-119'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-119'>#</a>
|
|
</div>
|
|
<p>Weight of KL divergence loss, $w_{KL}$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">492</span> <span class="n">kl_div_loss_weight</span> <span class="o">=</span> <span class="mf">0.5</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-120'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-120'>#</a>
|
|
</div>
|
|
<p>Gradient clipping</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">494</span> <span class="n">grad_clip</span> <span class="o">=</span> <span class="mf">1.</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-121'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-121'>#</a>
|
|
</div>
|
|
<p>Temperature $\tau$ for sampling</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">496</span> <span class="n">temperature</span> <span class="o">=</span> <span class="mf">0.4</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-122'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-122'>#</a>
|
|
</div>
|
|
<p>Filter out stroke sequences longer than $200$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">499</span> <span class="n">max_seq_length</span> <span class="o">=</span> <span class="mi">200</span>
|
|
<span class="lineno">500</span>
|
|
<span class="lineno">501</span> <span class="n">epochs</span> <span class="o">=</span> <span class="mi">100</span>
|
|
<span class="lineno">502</span>
|
|
<span class="lineno">503</span> <span class="n">kl_div_loss</span> <span class="o">=</span> <span class="n">KLDivLoss</span><span class="p">()</span>
|
|
<span class="lineno">504</span> <span class="n">reconstruction_loss</span> <span class="o">=</span> <span class="n">ReconstructionLoss</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-123'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-123'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">506</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-124'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-124'>#</a>
|
|
</div>
|
|
<p>Initialize encoder &amp; decoder</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">508</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">EncoderRNN</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_z</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">enc_hidden_size</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="lineno">509</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">DecoderRNN</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_z</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dec_hidden_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_distributions</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-125'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-125'>#</a>
|
|
</div>
|
|
<p>Set up the optimizer. Things like the type of optimizer and the learning rate are configurable.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">512</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">OptimizerConfigs</span><span class="p">()</span>
|
|
<span class="lineno">513</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">parameters</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
|
|
<span class="lineno">514</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-126'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-126'>#</a>
|
|
</div>
|
|
<p>Create sampler</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">517</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">Sampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-127'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-127'>#</a>
|
|
</div>
|
|
<p><code>npz</code> file path is <code>data/sketch/[DATASET NAME].npz</code></p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">520</span> <span class="n">path</span> <span class="o">=</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">'sketch'</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_name</span><span class="si">}</span><span class="s1">.npz'</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-128'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-128'>#</a>
|
|
</div>
|
|
<p>Load the numpy file</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">522</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">path</span><span class="p">),</span> <span class="n">encoding</span><span class="o">=</span><span class="s1">'latin1'</span><span class="p">,</span> <span class="n">allow_pickle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-129'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-129'>#</a>
|
|
</div>
|
|
<p>Create training dataset</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">525</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_dataset</span> <span class="o">=</span> <span class="n">StrokesDataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="s1">'train'</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-130'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-130'>#</a>
|
|
</div>
|
|
<p>Create validation dataset</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">527</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_dataset</span> <span class="o">=</span> <span class="n">StrokesDataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="s1">'valid'</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_dataset</span><span class="o">.</span><span class="n">scale</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-131'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-131'>#</a>
|
|
</div>
|
|
<p>Create training data loader</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">530</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_dataset</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-132'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-132'>#</a>
|
|
</div>
|
|
<p>Create validation data loader</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">532</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_dataset</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-133'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-133'>#</a>
|
|
</div>
|
|
<p>Add hooks to monitor layer outputs on TensorBoard</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">535</span> <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="s1">'encoder'</span><span class="p">)</span>
|
|
<span class="lineno">536</span> <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">,</span> <span class="s1">'decoder'</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-134'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-134'>#</a>
|
|
</div>
|
|
<p>Configure the tracker to print the total train/validation loss</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">539</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">"loss.total.*"</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
|
|
<span class="lineno">540</span>
|
|
<span class="lineno">541</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_modules</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-135'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-135'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">543</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span>
|
|
<span class="lineno">544</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span>
|
|
<span class="lineno">545</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-136'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-136'>#</a>
|
|
</div>
|
|
<p>Move <code>data</code> and <code>mask</code> to device and swap the sequence and batch dimensions.
|
|
<code>data</code> will have shape <code>[seq_len, batch_size, 5]</code> and
|
|
<code>mask</code> will have shape <code>[seq_len, batch_size]</code>.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">550</span> <span class="n">data</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
|
<span class="lineno">551</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-137'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-137'>#</a>
|
|
</div>
|
|
<p>Increment step in training mode</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">554</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
|
|
<span class="lineno">555</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-138'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-138'>#</a>
|
|
</div>
|
|
<p>Encode the sequence of strokes</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">558</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">"encoder"</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-139'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-139'>#</a>
|
|
</div>
|
|
<p>Get $z$, $\mu$, and $\hat{\sigma}$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">560</span> <span class="n">z</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">sigma_hat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-140'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-140'>#</a>
|
|
</div>
|
|
<p>Decode the mixture of distributions and $\hat{q}$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">563</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">"decoder"</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-141'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-141'>#</a>
|
|
</div>
|
|
<p>Concatenate $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">565</span> <span class="n">z_stack</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="lineno">566</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">data</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">z_stack</span><span class="p">],</span> <span class="mi">2</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-142'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-142'>#</a>
|
|
</div>
|
|
<p>Get mixture of distributions and $\hat{q}$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">568</span> <span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-143'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-143'>#</a>
|
|
</div>
|
|
<p>Compute the loss</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">571</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">'loss'</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-144'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-144'>#</a>
|
|
</div>
|
|
<p>$L_{KL}$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">573</span> <span class="n">kl_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_div_loss</span><span class="p">(</span><span class="n">sigma_hat</span><span class="p">,</span> <span class="n">mu</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-145'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-145'>#</a>
|
|
</div>
|
|
<p>$L_R$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">575</span> <span class="n">reconstruction_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reconstruction_loss</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">:],</span> <span class="n">dist</span><span class="p">,</span> <span class="n">q_logits</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-146'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-146'>#</a>
|
|
</div>
|
|
<p>$\text{Loss} = L_R + w_{KL} L_{KL}$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">577</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">reconstruction_loss</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_div_loss_weight</span> <span class="o">*</span> <span class="n">kl_loss</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-147'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-147'>#</a>
|
|
</div>
|
|
<p>Track losses</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">580</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">"loss.kl."</span><span class="p">,</span> <span class="n">kl_loss</span><span class="p">)</span>
|
|
<span class="lineno">581</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">"loss.reconstruction."</span><span class="p">,</span> <span class="n">reconstruction_loss</span><span class="p">)</span>
|
|
<span class="lineno">582</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">"loss.total."</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-148'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-148'>#</a>
|
|
</div>
|
|
<p>Only if we are in training mode</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">585</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-149'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-149'>#</a>
|
|
</div>
|
|
<p>Run optimizer</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">587</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">'optimize'</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-150'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-150'>#</a>
|
|
</div>
|
|
<p>Set <code>grad</code> to zero</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">589</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-151'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-151'>#</a>
|
|
</div>
|
|
<p>Compute gradients</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">591</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-152'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-152'>#</a>
|
|
</div>
|
|
<p>Log model parameters and gradients (only on the last batch)</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">593</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">:</span>
|
|
<span class="lineno">594</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">encoder</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="n">decoder</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-153'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-153'>#</a>
|
|
</div>
|
|
<p>Clip gradients</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">596</span> <span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_clip</span><span class="p">)</span>
|
|
<span class="lineno">597</span> <span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_clip</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-154'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-154'>#</a>
|
|
</div>
|
|
<p>Optimize</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">599</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
|
|
<span class="lineno">600</span>
|
|
<span class="lineno">601</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-155'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-155'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">603</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-156'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-156'>#</a>
|
|
</div>
|
|
<p>Randomly pick a sample from the validation dataset to encode</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">605</span> <span class="n">data</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_dataset</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_dataset</span><span class="p">))]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-157'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-157'>#</a>
|
|
</div>
|
|
<p>Add batch dimension and move it to device</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">607</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-158'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-158'>#</a>
|
|
</div>
|
|
<p>Sample</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">609</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-159'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-159'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">612</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span>
|
|
<span class="lineno">613</span> <span class="n">configs</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span>
|
|
<span class="lineno">614</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"sketch_rnn"</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-160'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-160'>#</a>
|
|
</div>
|
|
<p>Pass a dictionary of configurations</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">617</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">configs</span><span class="p">,</span> <span class="p">{</span>
|
|
<span class="lineno">618</span> <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Adam'</span><span class="p">,</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-161'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-161'>#</a>
|
|
</div>
|
|
<p>We use a learning rate of <code>1e-3</code> because we can see results faster.
|
|
The paper suggests <code>1e-4</code>.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">621</span> <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">1e-3</span><span class="p">,</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-162'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-162'>#</a>
|
|
</div>
|
|
<p>Name of the dataset</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">623</span> <span class="s1">'dataset_name'</span><span class="p">:</span> <span class="s1">'bicycle'</span><span class="p">,</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-163'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-163'>#</a>
|
|
</div>
|
|
<p>Number of inner iterations within an epoch to switch between training, validation and sampling.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">625</span> <span class="s1">'inner_iterations'</span><span class="p">:</span> <span class="mi">10</span>
|
|
<span class="lineno">626</span> <span class="p">})</span>
|
|
<span class="lineno">627</span>
|
|
<span class="lineno">628</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-164'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-164'>#</a>
|
|
</div>
|
|
<p>Run the experiment</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">630</span> <span class="n">configs</span><span class="o">.</span><span class="n">run</span><span class="p">()</span>
|
|
<span class="lineno">631</span>
|
|
<span class="lineno">632</span>
|
|
<span class="lineno">633</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">"__main__"</span><span class="p">:</span>
|
|
<span class="lineno">634</span> <span class="n">main</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
|
</script>
|
|
<!-- MathJax configuration -->
|
|
<script type="text/x-mathjax-config">
|
|
// MathJax 2 configuration for this page: turns on `$...$` inline and
// `$$...$$` display TeX delimiters used throughout the annotations.
MathJax.Hub.Config({
|
|
tex2jax: {
|
|
// Inline math delimiters: `$...$`.
inlineMath: [ ['$','$'] ],
|
|
// Display math delimiters: `$$...$$`.
displayMath: [ ['$$','$$'] ],
|
|
// Per the MathJax tex2jax docs, lets `\$` produce a literal dollar sign.
processEscapes: true,
|
|
// Per the MathJax tex2jax docs, also typesets `\begin{...}...\end{...}`
// environments that appear outside math delimiters.
processEnvironments: true
|
|
},
|
|
// Center-align display equations. Any left-justification of particular
// equations would have to come from the page stylesheet (not visible here).
|
|
displayAlign: 'center',
|
|
// Render with the TeX web fonts in the HTML-CSS output processor.
"HTML-CSS": { fonts: ["TeX"] }
|
|
});
|
|
</script>
|
|
<script>
|
|
function handleImages() {
    // Attach a click-to-zoom modal to every image that is a direct child of a
    // paragraph (`p>img`), i.e. the figures embedded in the annotation prose.
    // The previous debug `console.log(images)` is dropped so the page does not
    // spam every visitor's console.
    var images = document.querySelectorAll('p>img')

    for (var i = 0; i < images.length; ++i) {
        handleImage(images[i])
    }
}
|
|
|
|
function handleImage(img) {
    // Build a click-to-zoom modal for one documentation image. The modal is
    // created once per image and attached to <body> only while it is open.
    img.parentElement.style.textAlign = 'center'

    var modal = document.createElement('div')
    modal.id = 'modal'

    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)

    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)

    // A real <button> (instead of a clickable <span>) is keyboard-focusable and
    // announced by screen readers; the `close` class is kept so existing
    // styling hooks still match.
    var closeButton = document.createElement('button')
    closeButton.type = 'button'
    closeButton.classList.add('close')
    closeButton.setAttribute('aria-label', 'Close')
    closeButton.textContent = 'x'
    modal.appendChild(closeButton)

    // Detach the modal (if attached) and stop listening for Escape.
    function closeModal() {
        document.removeEventListener('keydown', onKeyDown)
        if (modal.parentNode) {
            modal.parentNode.removeChild(modal)
        }
    }

    // Let keyboard users dismiss the modal with Escape.
    function onKeyDown(event) {
        if (event.key === 'Escape') {
            closeModal()
        }
    }

    img.onclick = function () {
        modalImage.src = img.src
        // Carry the original description over to the enlarged copy.
        modalImage.alt = img.alt
        document.body.appendChild(modal)
        document.addEventListener('keydown', onKeyDown)
        // Move focus into the modal so the close control is immediately reachable.
        closeButton.focus()
    }

    closeButton.onclick = closeModal
}
|
|
|
|
// Attach the zoom modal to all prose images. This script sits at the end of
// <body>, so the images above have already been parsed when it runs.
handleImages()
|
|
</script>
|
|
</body>
|
|
</html> |