This commit is contained in:
Varuna Jayasiri
2022-05-23 22:26:39 +05:30
committed by GitHub
parent 5d3348d0c3
commit 6a41c82b30
16 changed files with 1947 additions and 21 deletions

View File

@ -0,0 +1,833 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="Training a transformer with FTA in FFN on Tiny Shakespeare."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Fuzzy Tiling Activation Experiment"/>
<meta name="twitter:description" content="Training a transformer with FTA in FFN on Tiny Shakespeare."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/activations/fta/experiment.html"/>
<meta property="og:title" content="Fuzzy Tiling Activation Experiment"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Fuzzy Tiling Activation Experiment"/>
<meta property="og:description" content="Training a transformer with FTA in FFN on Tiny Shakespeare."/>
<title>Fuzzy Tiling Activation Experiment</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/activations/fta/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">activations</a>
<a class="parent" href="index.html">fta</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/activations/fta/experiment.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="index.html">Fuzzy Tiling Activation</a> Experiment</h1>
<p>Here we train a transformer that uses <a href="index.html">Fuzzy Tiling Activation</a> in the <a href="../../transformers/feed_forward.html">Feed-Forward Network</a>. We use it for a language model and train it on Tiny Shakespeare dataset for demonstration.</p>
<p>However, this is probably not the ideal task for FTA, and we believe FTA is more suitable for modeling data with continuous variables.</p>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step"><img alt="Open In Comet" src="https://images.labml.ai/images/comet.svg?experiment=capsule_networks&file=model"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">22</span><span></span><span class="kn">import</span> <span class="nn">copy</span>
<span class="lineno">23</span>
<span class="lineno">24</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">25</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">26</span>
<span class="lineno">27</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">labml_nn.activations.fta</span> <span class="kn">import</span> <span class="n">FTA</span>
<span class="lineno">31</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.nlp_autoregression</span> <span class="kn">import</span> <span class="n">NLPAutoRegressionConfigs</span>
<span class="lineno">32</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span><span class="p">,</span> <span class="n">TransformerLayer</span>
<span class="lineno">33</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.utils</span> <span class="kn">import</span> <span class="n">subsequent_mask</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>FFN module with <a href="index.html">FTA</a> activation</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">36</span><span class="k">class</span> <span class="nc">FeedForwardFTA</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the number of features in a token embedding </li>
<li><code class="highlight"><span></span><span class="n">d_ff</span></code>
is the number of features in the hidden layer of the FFN </li>
<li><code class="highlight"><span></span><span class="n">activation</span></code>
is FTA activation module </li>
<li><code class="highlight"><span></span><span class="n">dropout</span></code>
is dropout probability for the hidden layer</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">41</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">42</span> <span class="n">activation</span><span class="p">:</span> <span class="n">FTA</span><span class="p">,</span>
<span class="lineno">43</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">50</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Layer one parameterized by weight <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqe" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> and bias <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqf" style=""><span class="mord" style=""><span class="mord mathnormal" style="">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">52</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Layer two parameterized by weight <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqe" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> and bias <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqf" style=""><span class="mord" style=""><span class="mord mathnormal" style="">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_ff</span> <span class="o">*</span> <span class="n">activation</span><span class="o">.</span><span class="n">expansion_factor</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Hidden layer dropout </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Activation function <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mord coloredeq eqe" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqf" style=""><span class="mord" style=""><span class="mord mathnormal" style="">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layer1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Apply dropout </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<h2>Auto-Regressive model</h2>
<p>This is an autoregressive transformer model that uses Feed-Forward Networks with (Fuzzy Tiling Activations)(index.html).</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span><span class="k">class</span> <span class="nc">AutoregressiveTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<ul><li><code>n_tokens</code> is the number of tokens in the vocabulary </li>
<li><code>d_model</code> is the embedding size </li>
<li><code>n_layers</code> is the number of transformer layers </li>
<li><code>layer</code> is the layer. We use <code class="highlight"><span></span><span class="n">n_layers</span></code>
copies of this for the transformer.</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">77</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">84</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Transformer with <code class="highlight"><span></span><span class="n">n_layers</span></code>
layers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer_layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_layers</span><span class="p">)])</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Token embedding layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span> <span class="bp">self</span><span class="o">.</span><span class="n">emb</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">n_tokens</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Readout layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="bp">self</span><span class="o">.</span><span class="n">readout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_tokens</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>The mask will be initialized on the first call </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">94</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<ul><li><code>x</code> are the input tokens of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">96</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Create auto-regressive mask </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Subsequent mask, will mask out tokens from seeing future tokens </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span> <span class="o">=</span> <span class="n">subsequent_mask</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Get the token embeddings </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">emb</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Transformer encoder </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer_layers</span><span class="p">:</span>
<span class="lineno">109</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Get logits </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">111</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">readout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Return results </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<h2>Configurations</h2>
<p>This inherits from <a href="../../experiments/nlp_autoregression.html#NLPAutoRegressionConfigs"><code class="highlight"><span></span><span class="n">NLPAutoRegressionConfigs</span></code>
</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">117</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Model </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">model</span><span class="p">:</span> <span class="n">AutoregressiveTransformer</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Number of layers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.0037em;">α</span></span></span></span> and <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord mathnormal" style="margin-right:0.05278em;">β</span></span></span></span> for DeepNorm </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">132</span> <span class="n">deep_norm_alpha</span><span class="p">:</span> <span class="nb">float</span>
<span class="lineno">133</span> <span class="n">deep_norm_beta</span><span class="p">:</span> <span class="nb">float</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Number of heads in the attention </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Embedding size </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Size of each attention head </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">d_k</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Feed forward layer size </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>FTA </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">fta_lower_limit</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">-</span><span class="mf">1.</span>
<span class="lineno">146</span> <span class="n">fta_upper_limit</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">+</span><span class="mf">1.</span>
<span class="lineno">147</span> <span class="n">fta_delta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.2</span>
<span class="lineno">148</span> <span class="n">fta_eta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.05</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<h4>Initialize the model</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">151</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">152</span><span class="k">def</span> <span class="nf">_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Create FTA activation module </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">fta</span> <span class="o">=</span> <span class="n">FTA</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">fta_lower_limit</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">fta_upper_limit</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">fta_delta</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">fta_eta</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Create the transformer. We re-use <a href="../../transformers/models.html#TransformerLayer"><code class="highlight"><span></span><span class="n">TransformerLayer</span></code>
</a> and <a href="../../transformers/mha.html"><code class="highlight"><span></span><span class="n">MultiHeadAttention</span></code>
</a> implementations. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="n">m</span> <span class="o">=</span> <span class="n">AutoregressiveTransformer</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">,</span>
<span class="lineno">163</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span>
<span class="lineno">164</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">FeedForwardFTA</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span>
<span class="lineno">165</span> <span class="n">d_ff</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_ff</span><span class="p">,</span>
<span class="lineno">166</span> <span class="n">activation</span><span class="o">=</span><span class="n">fta</span><span class="p">,</span>
<span class="lineno">167</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span><span class="p">),</span>
<span class="lineno">168</span> <span class="n">self_attn</span><span class="o">=</span><span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span>
<span class="lineno">169</span> <span class="n">dropout_prob</span><span class="o">=</span><span class="mf">0.0</span><span class="p">),</span>
<span class="lineno">170</span> <span class="n">dropout_prob</span><span class="o">=</span><span class="mf">0.0</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Move to the device </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="k">return</span> <span class="n">m</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<h4>Create and run the experiment</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Create experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;fta&quot;</span><span class="p">,</span> <span class="n">writers</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;screen&#39;</span><span class="p">,</span> <span class="s1">&#39;comet&#39;</span><span class="p">,</span> <span class="s1">&#39;labml&#39;</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Create configs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">183</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Override configurations </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">185</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Use character level tokenizer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span> <span class="s1">&#39;tokenizer&#39;</span><span class="p">:</span> <span class="s1">&#39;character&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Prompt separator is blank </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">189</span> <span class="s1">&#39;prompt_separator&#39;</span><span class="p">:</span> <span class="s1">&#39;&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Starting prompt for sampling </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="s1">&#39;prompt&#39;</span><span class="p">:</span> <span class="s1">&#39;It is &#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Use Tiny Shakespeare dataset </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="s1">&#39;text&#39;</span><span class="p">:</span> <span class="s1">&#39;tiny_shakespeare&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>Use a context size of <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">256</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="s1">&#39;seq_len&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>Train for 32 epochs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">198</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Batch size <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">16</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="s1">&#39;batch_size&#39;</span><span class="p">:</span> <span class="mi">16</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Switch between training and validation for <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">10</span></span></span></span> times per epoch </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">202</span> <span class="s1">&#39;inner_iterations&#39;</span><span class="p">:</span> <span class="mi">10</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>Adam optimizer with no warmup </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span>
<span class="lineno">206</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">3e-4</span><span class="p">,</span>
<span class="lineno">207</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>Set model(s) for saving and loading </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">210</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Start the experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">213</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>Run training </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">215</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">219</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">220</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src=../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

View File

@ -0,0 +1,402 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="PyTorch implementation and tutorial of Fuzzy Tiling Activations from the paper Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Fuzzy Tiling Activations"/>
<meta name="twitter:description" content="PyTorch implementation and tutorial of Fuzzy Tiling Activations from the paper Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/activations/fta/index.html"/>
<meta property="og:title" content="Fuzzy Tiling Activations"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Fuzzy Tiling Activations"/>
<meta property="og:description" content="PyTorch implementation and tutorial of Fuzzy Tiling Activations from the paper Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online."/>
<title>Fuzzy Tiling Activations</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/activations/fta/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">activations</a>
<a class="parent" href="index.html">fta</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/activations/fta/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Fuzzy Tiling Activations (FTA)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation/tutorial of <a href="https://papers.labml.ai/paper/aca66d8edc8911eba3db37f65e372566">Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online</a>.</p>
<p>Fuzzy tiling activations are a form of sparse activations based on binning.</p>
<p>Binning is classification of a scalar value into a bin based on intervals. One problem with binning is that it gives zero gradients for most values (except at the boundary of bins). The other is that binning loses precision if the bin intervals are large.</p>
<p>FTA overcomes these disadvantages. Instead of hard boundaries like in Tiling Activations, FTA uses soft boundaries between bins. This gives non-zero gradients for all or a wide range of values. And also doesn&#x27;t lose precision since it&#x27;s captured in partial values.</p>
<h4>Tiling Activations</h4>
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.44444em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathbf" style="">c</span></span></span></span></span> is the tiling vector,</p>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.44444em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathbf" style="">c</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord coloredeq eqq" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqq" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord mathnormal" style="margin-right:0.03785em;">δ</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqq" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord">2</span><span class="mord mathnormal" style="margin-right:0.03785em;">δ</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="minner"></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqr" style=""><span class="mord mathnormal" style="">u</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord">2</span><span class="mord mathnormal" style="margin-right:0.03785em;">δ</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqr" style=""><span class="mord mathnormal" style="">u</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span></span><span class="mclose">)</span></span></span></span></span></p>
<p>where <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord coloredeq eqq" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqr" style=""><span class="mord mathnormal" style="">u</span></span><span class="mclose">]</span></span></span></span> is the input range, <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span></span></span></span></span> is the bin size, and <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.77777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqm" style=""><span class="mord coloredeq eqr" style=""><span class="mord mathnormal" style="">u</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style=""></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eqq" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span></span></span></span></span></span> is divisible by <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span></span></span></span></span>.</p>
<p>Tiling activation is,</p>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">ϕ</span><span class="mopen">(</span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqp" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.20001em;vertical-align:-0.35001em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">I</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.25833100000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.07847em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">+</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size1">(</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">max</span><span class="mopen">(</span><span class="mord coloredeq eqj" style=""><span class="mord mathbf" style="">c</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqo" style=""><span class="mord" style="">0</span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">max</span><span class="mopen">(</span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.77777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.20001em;vertical-align:-0.35001em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathbf" style="">c</span></span><span class="mclose">)</span><span class="mord"><span class="delimsizing size1">)</span></span></span></span></span></span></p>
<p>where <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">I</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.25833100000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.07847em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">+</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"></span><span class="mclose">)</span></span></span></span> is the indicator function which gives <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqp" style=""><span class="mord" style="">1</span></span></span></span></span> if the input is positive and <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqo" style=""><span class="mord" style="">0</span></span></span></span></span> otherwise.</p>
<p>Note that tiling activation gives zero gradients because it has hard boundaries.</p>
<h4>Fuzzy Tiling Activations</h4>
<p>The fuzzy indicator function,</p>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">I</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.25833100000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.07847em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">η</span><span class="mpunct mtight">,</span><span class="mord mtight">+</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">I</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.25833100000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.07847em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">+</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.03588em">η</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">I</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.25833100000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.07847em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">+</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.03588em">η</span></span><span class="mclose">)</span></span></span></span></span></p>
<p>which increases linearly from <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqo" style=""><span class="mord" style="">0</span></span></span></span></span> to <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqp" style=""><span class="mord" style="">1</span></span></span></span></span> when <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.78041em;vertical-align:-0.13597em;"></span><span class="mord coloredeq eqo" style=""><span class="mord" style="">0</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.5782em;vertical-align:-0.0391em;"></span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">&lt;</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.03588em">η</span></span></span></span></span> and is equal to <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqp" style=""><span class="mord" style="">1</span></span></span></span></span> for <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8304100000000001em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.03588em">η</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal">x</span></span></span></span>. <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.03588em">η</span></span></span></span></span> is a hyper-parameter.</p>
<p>FTA uses this to create soft boundaries between bins.</p>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord mathnormal" style="">ϕ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqn" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">η</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mopen" style="">(</span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span><span class="mclose" style="">)</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqp" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.20001em;vertical-align:-0.35001em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">I</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.25833100000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.07847em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">η</span><span class="mpunct mtight">,</span><span class="mord mtight">+</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size1">(</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">max</span><span class="mopen">(</span><span class="mord coloredeq eqj" style=""><span class="mord mathbf" style="">c</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqo" style=""><span class="mord" style="">0</span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">max</span><span class="mopen">(</span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.77777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.20001em;vertical-align:-0.35001em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathbf" style="">c</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqo" style=""><span class="mord" style="">0</span></span><span class="mclose">)</span><span class="mord"><span class="delimsizing size1">)</span></span></span></span></span></span></p>
<p><a href="experiment.html">Here&#x27;s a simple experiment</a> that uses FTA in a transformer.</p>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step"><img alt="Open In Comet" src="https://images.labml.ai/images/comet.svg?experiment=capsule_networks&file=model"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">63</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h3>Fuzzy Tiling Activations (FTA)</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span><span class="k">class</span> <span class="nc">FTA</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code>lower_limit</code> is the lower limit <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqq" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span></span></span></span></span> </li>
<li><code>upper_limit</code> is the upper limit <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord mathnormal" style="">u</span></span></span></span></span> </li>
<li><code>delta</code> is the bin size <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span></span></span></span></span> </li>
<li><code>eta</code> is the parameter <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.03588em">η</span></span></span></span></span> that detemines the softness of the boundaries.</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lower_limit</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">upper_limit</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">delta</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">eta</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Initialize tiling vector <span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.44444em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathbf" style="">c</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord coloredeq eqq" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqq" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord mathnormal" style="margin-right:0.03785em;">δ</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqq" style=""><span class="mord mathnormal" style="margin-right:0.01968em">l</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord">2</span><span class="mord mathnormal" style="margin-right:0.03785em;">δ</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="minner"></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqr" style=""><span class="mord mathnormal" style="">u</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord">2</span><span class="mord mathnormal" style="margin-right:0.03785em;">δ</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqr" style=""><span class="mord mathnormal" style="">u</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span></span><span class="mclose">)</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span> <span class="bp">self</span><span class="o">.</span><span class="n">c</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">lower_limit</span><span class="p">,</span> <span class="n">upper_limit</span><span class="p">,</span> <span class="n">delta</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>The input vector expands by a factor equal to the number of bins <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.2251079999999999em;vertical-align:-0.345em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8801079999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="margin-right:0.03785em">δ</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqm" style=""><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="">u</span></span><span class="mbin mtight" style=""></span><span class="mord mtight coloredeq eqq" style=""><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span> <span class="bp">self</span><span class="o">.</span><span class="n">expansion_factor</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">c</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span> <span class="bp">self</span><span class="o">.</span><span class="n">delta</span> <span class="o">=</span> <span class="n">delta</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.03588em">η</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span> <span class="bp">self</span><span class="o">.</span><span class="n">eta</span> <span class="o">=</span> <span class="n">eta</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<h4>Fuzzy indicator function</h4>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">I</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.25833100000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.07847em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">η</span><span class="mpunct mtight">,</span><span class="mord mtight">+</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">I</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.25833100000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.07847em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">+</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.03588em">η</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">I</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.25833100000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.07847em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">+</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.03588em">η</span></span><span class="mclose">)</span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span> <span class="k">def</span> <span class="nf">fuzzy_i_plus</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span> <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">&lt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">eta</span><span class="p">)</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="p">(</span><span class="n">x</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">eta</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Add another dimension of size <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqp" style=""><span class="mord" style="">1</span></span></span></span></span>. We will expand this into bins. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span> <span class="n">z</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord mathnormal" style="">ϕ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqn" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">η</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mopen" style="">(</span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span><span class="mclose" style="">)</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqp" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.20001em;vertical-align:-0.35001em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">I</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.25833100000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.07847em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">η</span><span class="mpunct mtight">,</span><span class="mord mtight">+</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size1">(</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">max</span><span class="mopen">(</span><span class="mord coloredeq eqj" style=""><span class="mord mathbf" style="">c</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqo" style=""><span class="mord" style="">0</span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">max</span><span class="mopen">(</span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.77777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathnormal" style="margin-right:0.03785em">δ</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.20001em;vertical-align:-0.35001em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathbf" style="">c</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqo" style=""><span class="mord" style="">0</span></span><span class="mclose">)</span><span class="mord"><span class="delimsizing size1">)</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span> <span class="n">z</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">fuzzy_i_plus</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">c</span> <span class="o">-</span> <span class="n">z</span><span class="p">,</span> <span class="nb">min</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">z</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">delta</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">c</span><span class="p">,</span> <span class="nb">min</span><span class="o">=</span><span class="mf">0.</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Reshape back to original number of dimensions. The last dimension size gets expanded by the number of bins, <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.2251079999999999em;vertical-align:-0.345em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8801079999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="margin-right:0.03785em">δ</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqm" style=""><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="">u</span></span><span class="mbin mtight" style=""></span><span class="mord mtight coloredeq eqq" style=""><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span>. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span> <span class="k">return</span> <span class="n">z</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<h4>Code to test the FTA module</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">inspect</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Initialize </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">117</span> <span class="n">a</span> <span class="o">=</span> <span class="n">FTA</span><span class="p">(</span><span class="o">-</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mf">2.</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Print <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.44444em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathbf" style="">c</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="n">inspect</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">c</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Print number of bins <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.2251079999999999em;vertical-align:-0.345em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8801079999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="margin-right:0.03785em">δ</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqm" style=""><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="">u</span></span><span class="mbin mtight" style=""></span><span class="mord mtight coloredeq eqq" style=""><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="n">inspect</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">expansion_factor</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Input <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">1.1</span><span class="p">,</span> <span class="mf">2.2</span><span class="p">,</span> <span class="mf">3.3</span><span class="p">,</span> <span class="mf">4.4</span><span class="p">,</span> <span class="mf">5.5</span><span class="p">,</span> <span class="mf">6.6</span><span class="p">,</span> <span class="mf">7.7</span><span class="p">,</span> <span class="mf">8.8</span><span class="p">,</span> <span class="mf">9.</span><span class="p">,</span> <span class="mf">10.</span><span class="p">,</span> <span class="mf">11.</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Print <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">inspect</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Print <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord mathnormal" style="">ϕ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqn" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">η</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mopen" style="">(</span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span><span class="mclose" style="">)</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">inspect</span><span class="p">(</span><span class="n">a</span><span class="p">(</span><span class="n">z</span><span class="p">))</span>
<span class="lineno">129</span>
<span class="lineno">130</span>
<span class="lineno">131</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">132</span> <span class="n">_test</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src=../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

View File

@ -3,24 +3,24 @@
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="description" content="A set of PyTorch implementations/tutorials related to neural network activations"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="__init__.py"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:title" content="Neural Network Activation Functions"/>
<meta name="twitter:description" content="A set of PyTorch implementations/tutorials related to neural network activations"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/activations/index.html"/>
<meta property="og:title" content="__init__.py"/>
<meta property="og:title" content="Neural Network Activation Functions"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="__init__.py"/>
<meta property="og:description" content=""/>
<meta property="og:title" content="Neural Network Activation Functions"/>
<meta property="og:description" content="A set of PyTorch implementations/tutorials related to neural network activations"/>
<title>__init__.py</title>
<title>Neural Network Activation Functions</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/activations/index.html"/>
@ -64,14 +64,17 @@
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Neural Networks Activations</h1>
<ul><li><a href="fta/index.html">Fuzzy Tiling Activations</a> </li>
<li>🚧 <a href="swish/index.html">Swish</a></li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">1</span><span></span><span class="kn">from</span> <span class="nn">.swish</span> <span class="kn">import</span> <span class="n">Swish</span></pre></div>
<div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">from</span> <span class="nn">.swish</span> <span class="kn">import</span> <span class="n">Swish</span></pre></div>
</div>
</div>
<div class='footer'>

View File

@ -139,6 +139,8 @@
<ul><li><a href="adaptive_computation/ponder_net/index.html">PonderNet</a></li></ul>
<h4><a href="uncertainty/index.html">Uncertainty</a></h4>
<ul><li><a href="uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li></ul>
<h4><a href="activations/index.html">Activations</a></h4>
<ul><li><a href="activations/fta/index.html">Fuzzy Tiling Activations</a></li></ul>
<h2>Highlighted Research Paper PDFs</h2>
<ul><li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf">Autoregressive Search Engines: Generating Substrings as Document Identifiers</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf">Training Compute-Optimal Large Language Models</a> </li>

View File

@ -69,9 +69,8 @@
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>DeepNorm Experiment</h1>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/deep_norm/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://app.labml.ai/run/ec8e4dacb7f311ec8d1cd37d50b05c3d"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a> <a href="https://www.comet.ml/labml/deep-norm/61d817f80ff143c8825fba4aacd431d4?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&xAxis=step"><img alt="Open In Comet" src="https://images.labml.ai/images/comet.svg?experiment=deep_norm&file=experiment"></a></p>
<h1><a href="index.html">DeepNorm</a> Experiment</h1>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/deep_norm/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://app.labml.ai/run/ec8e4dacb7f311ec8d1cd37d50b05c3d"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">import</span> <span class="nn">copy</span>

View File

@ -85,7 +85,21 @@
<url>
<loc>https://nn.labml.ai/activations/index.html</loc>
<lastmod>2021-01-25T16:30:00+00:00</lastmod>
<lastmod>2022-05-23T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/activations/fta/index.html</loc>
<lastmod>2022-05-23T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/activations/fta/experiment.html</loc>
<lastmod>2022-05-23T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@ -197,7 +211,7 @@
<url>
<loc>https://nn.labml.ai/normalization/deep_norm/experiment.html</loc>
<lastmod>2022-04-23T16:30:00+00:00</lastmod>
<lastmod>2022-05-23T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@ -295,7 +309,7 @@
<url>
<loc>https://nn.labml.ai/index.html</loc>
<lastmod>2022-05-03T16:30:00+00:00</lastmod>
<lastmod>2022-05-23T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@ -589,7 +603,7 @@
<url>
<loc>https://nn.labml.ai/transformers/rope/index.html</loc>
<lastmod>2022-02-23T16:30:00+00:00</lastmod>
<lastmod>2022-04-05T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>

File diff suppressed because one or more lines are too long

View File

@ -112,6 +112,10 @@ Solving games with incomplete information such as poker with CFR.
* [Evidential Deep Learning to Quantify Classification Uncertainty](uncertainty/evidence/index.html)
#### ✨ [Activations](activations/index.html)
* [Fuzzy Tiling Activations](activations/fta/index.html)
## Highlighted Research Paper PDFs
* [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf)

View File

@ -1 +1,14 @@
"""
---
title: Neural Network Activation Functions
summary: >
A set of PyTorch implementations/tutorials related to neural network activations
---
# Neural Networks Activations
* [Fuzzy Tiling Activations](fta/index.html)
* 🚧 [Swish](swish/index.html)
"""
from .swish import Swish

View File

@ -0,0 +1,132 @@
"""
---
title: Fuzzy Tiling Activations
summary: >
PyTorch implementation and tutorial of Fuzzy Tiling Activations from the
paper Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online.
---
# Fuzzy Tiling Activations (FTA)
This is a [PyTorch](https://pytorch.org) implementation/tutorial of
[Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online](https://papers.labml.ai/paper/aca66d8edc8911eba3db37f65e372566).
Fuzzy tiling activations are a form of sparse activations based on binning.
Binning is classification of a scalar value into a bin based on intervals.
One problem with binning is that it gives zero gradients for most values (except at the boundary of bins).
The other is that binning loses precision if the bin intervals are large.
FTA overcomes these disadvantages.
Instead of hard boundaries like in Tiling Activations, FTA uses soft boundaries
between bins.
This gives non-zero gradients for all or a wide range of values.
And also doesn't lose precision since it's captured in partial values.
#### Tiling Activations
$\mathbf{c}$ is the tiling vector,
$$\mathbf{c} = (l, l + \delta, l + 2 \delta, \dots, u - 2 \delta, u - \delta)$$
where $[l, u]$ is the input range, $\delta$ is the bin size, and $u - l$ is divisible by $\delta$.
Tiling activation is,
$$\phi(z) = 1 - I_+ \big( \max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}) \big)$$
where $I_+(\cdot)$ is the indicator function which gives $1$ if the input is positive and $0$ otherwise.
Note that tiling activation gives zero gradients because it has hard boundaries.
#### Fuzzy Tiling Activations
The fuzzy indicator function,
$$I_{\eta,+}(x) = I_+(\eta - x) x + I_+ (x - \eta)$$
which increases linearly from $0$ to $1$ when $0 \le x \lt \eta$
and is equal to $1$ for $\eta \le x$.
$\eta$ is a hyper-parameter.
FTA uses this to create soft boundaries between bins.
$$\phi_\eta(z) = 1 - I_{\eta,+} \big( \max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0) \big)$$
[Here's a simple experiment](experiment.html) that uses FTA in a transformer.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb)
[![Open In Comet](https://images.labml.ai/images/comet.svg?experiment=capsule_networks&file=model)](https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step)
"""
import torch
from torch import nn
class FTA(nn.Module):
"""
### Fuzzy Tiling Activations (FTA)
"""
def __init__(self, lower_limit: float, upper_limit: float, delta: float, eta: float):
"""
:param lower_limit: is the lower limit $l$
:param upper_limit: is the upper limit $u$
:param delta: is the bin size $\delta$
:param eta: is the parameter $\eta$ that detemines the softness of the boundaries.
"""
super().__init__()
# Initialize tiling vector
# $$\mathbf{c} = (l, l + \delta, l + 2 \delta, \dots, u - 2 \delta, u - \delta)$$
self.c = nn.Parameter(torch.arange(lower_limit, upper_limit, delta), requires_grad=False)
# The input vector expands by a factor equal to the number of bins $\frac{u - l}{\delta}$
self.expansion_factor = len(self.c)
# $\delta$
self.delta = delta
# $\eta$
self.eta = eta
def fuzzy_i_plus(self, x: torch.Tensor):
"""
#### Fuzzy indicator function
$$I_{\eta,+}(x) = I_+(\eta - x) x + I_+ (x - \eta)$$
"""
return (x <= self.eta) * x + (x > self.eta)
def forward(self, z: torch.Tensor):
# Add another dimension of size $1$.
# We will expand this into bins.
z = z.view(*z.shape, 1)
# $$\phi_\eta(z) = 1 - I_{\eta,+} \big( \max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0) \big)$$
z = 1. - self.fuzzy_i_plus(torch.clip(self.c - z, min=0.) + torch.clip(z - self.delta - self.c, min=0.))
# Reshape back to original number of dimensions.
# The last dimension size gets expanded by the number of bins, $\frac{u - l}{\delta}$.
return z.view(*z.shape[:-2], -1)
def _test():
"""
#### Code to test the FTA module
"""
from labml.logger import inspect
# Initialize
a = FTA(-10, 10, 2., 0.5)
# Print $\mathbf{c}$
inspect(a.c)
# Print number of bins $\frac{u - l}{\delta}$
inspect(a.expansion_factor)
# Input $z$
z = torch.tensor([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9., 10., 11.])
# Print $z$
inspect(z)
# Print $\phi_\eta(z)$
inspect(a(z))
if __name__ == '__main__':
_test()

View File

@ -0,0 +1,299 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "AYV_dMVDxyc2",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"[![Github](https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social)](https://github.com/labmlai/annotated_deep_learning_paper_implementations)\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb)\n",
"[![Open In Comet](https://images.labml.ai/images/comet.svg?experiment=capsule_networks&file=model)](https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step)\n",
"\n",
"## [Fuzzy Tiling Activations](https://nn.labml.ai/activations/fta/index.html)\n",
"\n",
"Here we train a transformer that uses [Fuzzy Tiling Activation](https://nn.labml.ai/activations/fta/index.html) in the\n",
"[Feed-Forward Network](https://nn.labml.ai/transformers/feed_forward.html).\n",
"We use it for a language model and train it on Tiny Shakespeare dataset\n",
"for demonstration.\n",
"However, this is probably not the ideal task for FTA, and we\n",
"believe FTA is more suitable for modeling data with continuous variables."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AahG_i2y5tY9",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Install the packages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZCzmCrAIVg0L",
"outputId": "cf107fb2-4d50-4c67-af34-367624553421",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"!pip install labml-nn comet_ml --quiet"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Enable [Comet](https://www.comet.ml)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"#@markdown Select in order to enable logging this experiment to [Comet](https://www.comet.ml).\n",
"use_comet = False #@param {type:\"boolean\"}\n",
"\n",
"if use_comet:\n",
" import comet_ml\n",
" comet_ml.init(project_name='fta')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SE2VUQ6L5zxI",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"from labml import experiment\n",
"from labml.configs import option\n",
"from labml_nn.activations.fta.experiment import Configs"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"### Create an experiment"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"experiment.create(name=\"fta\", writers={\"screen\", \"comet\"} if use_comet else {'screen'})"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"### Configurations"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"conf = Configs()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"Set experiment configurations and assign a configurations dictionary to override configurations"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"experiment.configs(conf, {\n",
" 'tokenizer': 'character',\n",
" 'prompt_separator': '',\n",
" 'prompt': 'It is ',\n",
" 'text': 'tiny_shakespeare',\n",
"\n",
" 'seq_len': 256,\n",
" 'epochs': 32,\n",
" 'batch_size': 16,\n",
" 'inner_iterations': 10,\n",
"\n",
" 'optimizer.optimizer': 'Adam',\n",
" 'optimizer.learning_rate': 3e-4,\n",
"})"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "EvI7MtgJ61w5",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Set PyTorch models for loading and saving"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 255
},
"id": "GDlt7dp-5ALt",
"outputId": "e7548e8f-c541-4618-dc5a-1597cae42003",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"experiment.add_pytorch_models({'model': conf.model})"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KJZRf8527GxL",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Start the experiment and run the training loop."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "aIAWo7Fw5DR8",
"outputId": "db979785-bfe3-4eda-d3eb-8ccbe61053e5",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Start the experiment\n",
"with experiment.start():\n",
" conf.run()"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "FTA",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -0,0 +1,220 @@
"""
---
title: Fuzzy Tiling Activation Experiment
summary: >
Training a transformer with FTA in FFN on Tiny Shakespeare.
---
# [Fuzzy Tiling Activation](index.html) Experiment
Here we train a transformer that uses [Fuzzy Tiling Activation](index.html) in the
[Feed-Forward Network](../../transformers/feed_forward.html).
We use it for a language model and train it on Tiny Shakespeare dataset
for demonstration.
However, this is probably not the ideal task for FTA, and we
believe FTA is more suitable for modeling data with continuous variables.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb)
[![Open In Comet](https://images.labml.ai/images/comet.svg?experiment=capsule_networks&file=model)](https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step)
"""
import copy
import torch
import torch.nn as nn
from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.activations.fta import FTA
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask
class FeedForwardFTA(nn.Module):
"""
## FFN module with [FTA](index.html) activation
"""
def __init__(self, d_model: int, d_ff: int,
activation: FTA,
dropout: float = 0.1):
"""
* `d_model` is the number of features in a token embedding
* `d_ff` is the number of features in the hidden layer of the FFN
* `activation` is FTA activation module
* `dropout` is dropout probability for the hidden layer
"""
super().__init__()
# Layer one parameterized by weight $W_1$ and bias $b_1$
self.layer1 = nn.Linear(d_model, d_ff)
# Layer two parameterized by weight $W_1$ and bias $b_1$
self.layer2 = nn.Linear(d_ff * activation.expansion_factor, d_model)
# Hidden layer dropout
self.dropout = nn.Dropout(dropout)
# Activation function $f$
self.activation = activation
def forward(self, x: torch.Tensor):
# $f(x W_1 + b_1)$
x = self.activation(self.layer1(x))
# Apply dropout
x = self.dropout(x)
#
return self.layer2(x)
class AutoregressiveTransformer(Module):
"""
## Auto-Regressive model
This is an autoregressive transformer model that uses Feed-Forward Networks with
(Fuzzy Tiling Activations)(index.html).
"""
def __init__(self, n_tokens: int, d_model: int, n_layers: int, layer: TransformerLayer):
"""
:param n_tokens: is the number of tokens in the vocabulary
:param d_model: is the embedding size
:param n_layers: is the number of transformer layers
:param layer: is the layer. We use `n_layers` copies of this for the transformer.
"""
super().__init__()
# Transformer with `n_layers` layers
self.transformer_layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
# Token embedding layer
self.emb = nn.Embedding(n_tokens, d_model)
# Readout layer
self.readout = nn.Linear(d_model, n_tokens)
# The mask will be initialized on the first call
self.mask = None
def forward(self, x: torch.Tensor):
"""
:param x: are the input tokens of shape `[seq_len, batch_size]`
"""
# Create auto-regressive mask
if self.mask is None or self.mask.size(0) != len(x):
# Subsequent mask, will mask out tokens from seeing future tokens
self.mask = subsequent_mask(len(x)).to(x.device)
# Get the token embeddings
x = self.emb(x)
# Transformer encoder
for layer in self.transformer_layers:
x = layer(x=x, mask=self.mask)
# Get logits
x = self.readout(x)
# Return results
return x, None
class Configs(NLPAutoRegressionConfigs):
"""
## Configurations
This inherits from
[`NLPAutoRegressionConfigs`](../../experiments/nlp_autoregression.html#NLPAutoRegressionConfigs)
"""
# Model
model: AutoregressiveTransformer
# Number of layers
n_layers: int = 4
# $\alpha$ and $\beta$ for DeepNorm
deep_norm_alpha: float
deep_norm_beta: float
# Number of heads in the attention
n_heads: int = 4
# Embedding size
d_model: int = 256
# Size of each attention head
d_k: int = 16
# Feed forward layer size
d_ff: int = 256
# FTA
fta_lower_limit: float = -1.
fta_upper_limit: float = +1.
fta_delta: float = 0.2
fta_eta: float = 0.05
@option(Configs.model)
def _model(c: Configs):
"""
#### Initialize the model
"""
# Create FTA activation module
fta = FTA(c.fta_lower_limit, c.fta_upper_limit, c.fta_delta, c.fta_eta)
# Create the transformer.
# We re-use [`TransformerLayer`](../../transformers/models.html#TransformerLayer) and
# [`MultiHeadAttention`](../../transformers/mha.html) implementations.
m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.n_layers,
TransformerLayer(d_model=c.d_model,
feed_forward=FeedForwardFTA(d_model=c.d_model,
d_ff=c.d_ff,
activation=fta,
dropout=0.1),
self_attn=MultiHeadAttention(c.n_heads, c.d_model,
dropout_prob=0.0),
dropout_prob=0.0))
# Move to the device
return m.to(c.device)
def main():
"""
#### Create and run the experiment
"""
# Create experiment
experiment.create(name="fta", writers={'screen', 'comet', 'labml'})
# Create configs
conf = Configs()
# Override configurations
experiment.configs(conf, {
# Use character level tokenizer
'tokenizer': 'character',
# Prompt separator is blank
'prompt_separator': '',
# Starting prompt for sampling
'prompt': 'It is ',
# Use Tiny Shakespeare dataset
'text': 'tiny_shakespeare',
# Use a context size of $256$
'seq_len': 256,
# Train for 32 epochs
'epochs': 32,
# Batch size $16$
'batch_size': 16,
# Switch between training and validation for $10$ times per epoch
'inner_iterations': 10,
# Adam optimizer with no warmup
'optimizer.optimizer': 'Adam',
'optimizer.learning_rate': 3e-4,
})
# Set model(s) for saving and loading
experiment.add_pytorch_models({'model': conf.model})
# Start the experiment
with experiment.start():
# Run training
conf.run()
#
if __name__ == '__main__':
main()

View File

@ -5,7 +5,7 @@ summary: >
Training a DeepNorm transformer on Tiny Shakespeare.
---
# DeepNorm Experiment
# [DeepNorm](index.html) Experiment
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/deep_norm/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/ec8e4dacb7f311ec8d1cd37d50b05c3d)

View File

@ -141,7 +141,7 @@ class RotaryPositionalEmbeddings(nn.Module):
idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
# Concatenate so that for row $m$ we have
# $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta 0, m \theta 1, ..., m \theta_{\frac{d}{2}}]$
# $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., -x^{(\frac{d}{2})}]$

View File

@ -115,6 +115,11 @@ Solving games with incomplete information such as poker with CFR.
* [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
#### ✨ [Activations](https://nn.labml.ai/activations/index.html)
* [Fuzzy Tiling Activations](https://nn.labml.ai/activations/fta/index.html)
## Highlighted Research Paper PDFs
* [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf)

View File

@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
setuptools.setup(
name='labml-nn',
version='0.4.121',
version='0.4.122',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
description="🧑‍🏫 Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit), optimizers (adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, etc. 🧠",