Files
2024-06-21 19:35:22 +05:30

276 lines
17 KiB
HTML
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This creates data for Parity Task from the paper Adaptive Computation Time for Recurrent Neural Networks"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Parity Task"/>
<meta name="twitter:description" content="This creates data for Parity Task from the paper Adaptive Computation Time for Recurrent Neural Networks"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/adaptive_computation/parity.html"/>
<meta property="og:title" content="Parity Task"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Parity Task"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Parity Task"/>
<meta property="og:description" content="This creates data for Parity Task from the paper Adaptive Computation Time for Recurrent Neural Networks"/>
<title>Parity Task</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/adaptive_computation/parity.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">adaptive_computation</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/adaptive_computation/parity.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Parity Task</h1>
<p>This creates data for Parity Task from the paper <a href="https://arxiv.org/abs/1603.08983">Adaptive Computation Time for Recurrent Neural Networks</a>.</p>
<p>The input of the parity task is a vector with <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">0</span></span></span></span></span>&#x27;s <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">1</span></span></span></span></span></span>&#x27;s and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""></span><span class="mord" style=""><span class="mord coloredeq eqc" style="">1</span></span></span></span></span></span></span>&#x27;s. The output is the parity of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">1</span></span></span></span></span></span>&#x27;s - one if there is an odd number of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">1</span></span></span></span></span></span>&#x27;s and zero otherwise. The input is generated by making a random number of elements in the vector either <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">1</span></span></span></span></span></span> or <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""></span><span class="mord" style=""><span class="mord coloredeq eqc" style="">1</span></span></span></span></span></span></span>&#x27;s.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">19</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span>
<span class="lineno">20</span>
<span class="lineno">21</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h3>Parity dataset</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">25</span><span class="k">class</span> <span class="nc">ParityDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">n_samples</span></code>
is the number of samples </li>
<li><code class="highlight"><span></span><span class="n">n_elems</span></code>
is the number of elements in the input vector</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">30</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_samples</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_elems</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_samples</span> <span class="o">=</span> <span class="n">n_samples</span>
<span class="lineno">36</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_elems</span> <span class="o">=</span> <span class="n">n_elems</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p> Size of the dataset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">38</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_samples</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p> Generate a sample</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Empty vector </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">50</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_elems</span><span class="p">,))</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Number of non-zero elements - a random number between <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">1</span></span></span></span></span></span> and total number of elements </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">52</span> <span class="n">n_non_zero</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_elems</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,))</span><span class="o">.</span><span class="n">item</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Fill non-zero elements with <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">1</span></span></span></span></span></span>&#x27;s and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""></span><span class="mord" style=""><span class="mord coloredeq eqc" style="">1</span></span></span></span></span></span></span>&#x27;s </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">x</span><span class="p">[:</span><span class="n">n_non_zero</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="p">(</span><span class="n">n_non_zero</span><span class="p">,))</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">-</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Randomly permute the elements </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">randperm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_elems</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>The parity </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">y</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">==</span> <span class="mf">1.</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">%</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src=../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>