FTA (#115)
docs/activations/fta/experiment.html (new file, 833 lines)
@@ -0,0 +1,833 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="Training a transformer with FTA in FFN on Tiny Shakespeare."/>

<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
<meta name="twitter:title" content="Fuzzy Tiling Activation Experiment"/>
<meta name="twitter:description" content="Training a transformer with FTA in FFN on Tiny Shakespeare."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>

<meta property="og:url" content="https://nn.labml.ai/activations/fta/experiment.html"/>
<meta property="og:title" content="Fuzzy Tiling Activation Experiment"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Fuzzy Tiling Activation Experiment"/>
<meta property="og:description" content="Training a transformer with FTA in FFN on Tiny Shakespeare."/>

<title>Fuzzy Tiling Activation Experiment</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/activations/fta/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">

<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];

function gtag() {
dataLayer.push(arguments);
}

gtag('js', new Date());

gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
|
||||||
|
<div id='container'>
|
||||||
|
<div id="background"></div>
|
||||||
|
<div class='section'>
|
||||||
|
<div class='docs'>
|
||||||
|
<p>
|
||||||
|
<a class="parent" href="/">home</a>
|
||||||
|
<a class="parent" href="../index.html">activations</a>
|
||||||
|
<a class="parent" href="index.html">fta</a>
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
|
||||||
|
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/activations/fta/experiment.py">
|
||||||
|
<img alt="Github"
|
||||||
|
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||||
|
style="max-width:100%;"/></a>
|
||||||
|
<a href="https://twitter.com/labmlai"
|
||||||
|
rel="nofollow">
|
||||||
|
<img alt="Twitter"
|
||||||
|
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||||
|
style="max-width:100%;"/></a>
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-0'>
|
||||||
|
<div class='docs doc-strings'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-0'>#</a>
|
||||||
|
</div>
|
||||||
|
<h1><a href="index.html">Fuzzy Tiling Activation</a> Experiment</h1>
|
||||||
|
<p>Here we train a transformer that uses <a href="index.html">Fuzzy Tiling Activation</a> in the <a href="../../transformers/feed_forward.html">Feed-Forward Network</a>. We use it for a language model and train it on the Tiny Shakespeare dataset for demonstration.</p>
<p>However, this is probably not the ideal task for FTA, and we believe FTA is more suitable for modeling data with continuous variables.</p>
|
||||||
|
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step"><img alt="Open In Comet" src="https://images.labml.ai/images/comet.svg?experiment=capsule_networks&file=model"></a></p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">22</span><span></span><span class="kn">import</span> <span class="nn">copy</span>
|
||||||
|
<span class="lineno">23</span>
|
||||||
|
<span class="lineno">24</span><span class="kn">import</span> <span class="nn">torch</span>
|
||||||
|
<span class="lineno">25</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
||||||
|
<span class="lineno">26</span>
|
||||||
|
<span class="lineno">27</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
|
||||||
|
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
|
||||||
|
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
||||||
|
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">labml_nn.activations.fta</span> <span class="kn">import</span> <span class="n">FTA</span>
|
||||||
|
<span class="lineno">31</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.nlp_autoregression</span> <span class="kn">import</span> <span class="n">NLPAutoRegressionConfigs</span>
|
||||||
|
<span class="lineno">32</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span><span class="p">,</span> <span class="n">TransformerLayer</span>
|
||||||
|
<span class="lineno">33</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.utils</span> <span class="kn">import</span> <span class="n">subsequent_mask</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-1'>
|
||||||
|
<div class='docs doc-strings'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-1'>#</a>
|
||||||
|
</div>
|
||||||
|
<h2>FFN module with <a href="index.html">FTA</a> activation</h2>
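<p>Before the implementation, a rough shape-check sketch (assuming the experiment defaults used below, for which the FTA expansion factor should be 10 bins per feature):</p>
<pre><code>import torch
from labml_nn.activations.fta import FTA

# Assumed values: the experiment defaults, where tiling [-1, 1] with
# delta = 0.2 should give an expansion factor of 10 bins per feature.
fta = FTA(-1., +1., 0.2, 0.05)          # lower limit, upper limit, delta, eta
ffn = FeedForwardFTA(d_model=256, d_ff=256,   # FeedForwardFTA as defined in this module
                     activation=fta, dropout=0.1)

x = torch.randn(8, 256)       # [batch_size, d_model]
h = fta(ffn.layer1(x))        # [batch_size, d_ff * 10] = [8, 2560]
y = ffn(x)                    # back to [batch_size, d_model] = [8, 256]</code></pre>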
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">36</span><span class="k">class</span> <span class="nc">FeedForwardFTA</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-2'>
|
||||||
|
<div class='docs doc-strings'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-2'>#</a>
|
||||||
|
</div>
|
||||||
|
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
|
||||||
|
is the number of features in a token embedding </li>
|
||||||
|
<li><code class="highlight"><span></span><span class="n">d_ff</span></code>
|
||||||
|
is the number of features in the hidden layer of the FFN </li>
|
||||||
|
<li><code class="highlight"><span></span><span class="n">activation</span></code>
|
||||||
|
is FTA activation module </li>
|
||||||
|
<li><code class="highlight"><span></span><span class="n">dropout</span></code>
|
||||||
|
is dropout probability for the hidden layer</li></ul>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">41</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||||||
|
<span class="lineno">42</span> <span class="n">activation</span><span class="p">:</span> <span class="n">FTA</span><span class="p">,</span>
|
||||||
|
<span class="lineno">43</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">):</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-3'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-3'>#</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">50</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-4'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-4'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Layer one parameterized by weight <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqe" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> and bias <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqf" style=""><span class="mord" style=""><span class="mord mathnormal" style="">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">52</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">)</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-5'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-5'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Layer two parameterized by weight <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqe" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> and bias <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqf" style=""><span class="mord" style=""><span class="mord mathnormal" style="">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">54</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_ff</span> <span class="o">*</span> <span class="n">activation</span><span class="o">.</span><span class="n">expansion_factor</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-6'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-6'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Hidden layer dropout </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">56</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-7'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-7'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Activation function <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span></span></span></span></span> </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">58</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-8'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-8'>#</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">60</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-9'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-9'>#</a>
|
||||||
|
</div>
|
||||||
|
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mord coloredeq eqe" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqf" style=""><span class="mord" style=""><span class="mord mathnormal" style="">b</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span> </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">62</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layer1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-10'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-10'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Apply dropout </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">64</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-11'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-11'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Project back to the embedding size with the second layer </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">66</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-12'>
|
||||||
|
<div class='docs doc-strings'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-12'>#</a>
|
||||||
|
</div>
|
||||||
|
<h2>Auto-Regressive model</h2>
<p>This is an autoregressive transformer model that uses Feed-Forward Networks with <a href="index.html">Fuzzy Tiling Activations</a>.</p>
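<p>A hypothetical standalone construction that mirrors the <code>_model</code> configuration option defined later; the vocabulary size of 65 is an assumed value for character-level Tiny Shakespeare:</p>
<pre><code>import torch

# Hypothetical construction, mirroring the _model() option further down.
# n_tokens = 65 is an assumed character-level vocabulary size.
layer = TransformerLayer(d_model=256,
                         feed_forward=FeedForwardFTA(d_model=256, d_ff=256,
                                                     activation=FTA(-1., +1., 0.2, 0.05),
                                                     dropout=0.1),
                         self_attn=MultiHeadAttention(4, 256, dropout_prob=0.0),
                         dropout_prob=0.0)
model = AutoregressiveTransformer(n_tokens=65, d_model=256, n_layers=4, layer=layer)

tokens = torch.randint(0, 65, (256, 16))   # [seq_len, batch_size]
logits, _ = model(tokens)                  # [seq_len, batch_size, n_tokens]</code></pre>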
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">69</span><span class="k">class</span> <span class="nc">AutoregressiveTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-13'>
|
||||||
|
<div class='docs doc-strings'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-13'>#</a>
|
||||||
|
</div>
|
||||||
|
<ul><li><code>n_tokens</code> is the number of tokens in the vocabulary </li>
<li><code>d_model</code> is the embedding size </li>
<li><code>n_layers</code> is the number of transformer layers </li>
<li><code>layer</code> is a transformer layer; we make <code class="highlight"><span></span><span class="n">n_layers</span></code> deep copies of it for the transformer.</li></ul>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">77</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">):</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-14'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-14'>#</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">84</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-15'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-15'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Transformer with <code class="highlight"><span></span><span class="n">n_layers</span></code> layers </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">86</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer_layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_layers</span><span class="p">)])</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-16'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-16'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Token embedding layer </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">89</span> <span class="bp">self</span><span class="o">.</span><span class="n">emb</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">n_tokens</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-17'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-17'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Readout layer </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">91</span> <span class="bp">self</span><span class="o">.</span><span class="n">readout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_tokens</span><span class="p">)</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-18'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-18'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>The mask will be initialized on the first call </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">94</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-19'>
|
||||||
|
<div class='docs doc-strings'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-19'>#</a>
|
||||||
|
</div>
|
||||||
|
<ul><li><code>x</code> are the input tokens of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">]</span></code></li></ul>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">96</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-20'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-20'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Create auto-regressive mask </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">101</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">):</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-21'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-21'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Subsequent mask that prevents tokens from attending to future tokens </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">103</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask</span> <span class="o">=</span> <span class="n">subsequent_mask</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-22'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-22'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Get the token embeddings </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">106</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">emb</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-23'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-23'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Transformer encoder </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer_layers</span><span class="p">:</span>
|
||||||
|
<span class="lineno">109</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">mask</span><span class="p">)</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-24'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-24'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Get logits </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">111</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">readout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-25'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-25'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Return results </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">114</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="kc">None</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-26'>
|
||||||
|
<div class='docs doc-strings'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-26'>#</a>
|
||||||
|
</div>
|
||||||
|
<h2>Configurations</h2>
|
||||||
|
<p>This inherits from <a href="../../experiments/nlp_autoregression.html#NLPAutoRegressionConfigs"><code class="highlight"><span></span><span class="n">NLPAutoRegressionConfigs</span></code></a></p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">117</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-27'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-27'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Model </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">model</span><span class="p">:</span> <span class="n">AutoregressiveTransformer</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-28'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-28'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Number of layers </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-29'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-29'>#</a>
|
||||||
|
</div>
|
||||||
|
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.0037em;">α</span></span></span></span> and <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord mathnormal" style="margin-right:0.05278em;">β</span></span></span></span> for DeepNorm </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">132</span> <span class="n">deep_norm_alpha</span><span class="p">:</span> <span class="nb">float</span>
|
||||||
|
<span class="lineno">133</span> <span class="n">deep_norm_beta</span><span class="p">:</span> <span class="nb">float</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-30'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-30'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Number of heads in the attention </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-31'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-31'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Embedding size </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-32'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-32'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Size of each attention head </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">d_k</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-33'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-33'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Feed forward layer size </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">142</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-34'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-34'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>FTA parameters: the lower and upper tiling limits, the bin size (delta), and the fuzziness parameter (eta) </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">fta_lower_limit</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">-</span><span class="mf">1.</span>
|
||||||
|
<span class="lineno">146</span> <span class="n">fta_upper_limit</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">+</span><span class="mf">1.</span>
|
||||||
|
<span class="lineno">147</span> <span class="n">fta_delta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.2</span>
|
||||||
|
<span class="lineno">148</span> <span class="n">fta_eta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.05</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-35'>
|
||||||
|
<div class='docs doc-strings'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-35'>#</a>
|
||||||
|
</div>
|
||||||
|
<h4>Initialize the model</h4>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">151</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
|
||||||
|
<span class="lineno">152</span><span class="k">def</span> <span class="nf">_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-36'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-36'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Create FTA activation module </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">fta</span> <span class="o">=</span> <span class="n">FTA</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">fta_lower_limit</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">fta_upper_limit</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">fta_delta</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">fta_eta</span><span class="p">)</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-37'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-37'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Create the transformer. We re-use <a href="../../transformers/models.html#TransformerLayer"><code class="highlight"><span></span><span class="n">TransformerLayer</span></code></a> and <a href="../../transformers/mha.html"><code class="highlight"><span></span><span class="n">MultiHeadAttention</span></code></a> implementations. </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">162</span> <span class="n">m</span> <span class="o">=</span> <span class="n">AutoregressiveTransformer</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">,</span>
|
||||||
|
<span class="lineno">163</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span>
|
||||||
|
<span class="lineno">164</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">FeedForwardFTA</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span>
|
||||||
|
<span class="lineno">165</span> <span class="n">d_ff</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_ff</span><span class="p">,</span>
|
||||||
|
<span class="lineno">166</span> <span class="n">activation</span><span class="o">=</span><span class="n">fta</span><span class="p">,</span>
|
||||||
|
<span class="lineno">167</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span><span class="p">),</span>
|
||||||
|
<span class="lineno">168</span> <span class="n">self_attn</span><span class="o">=</span><span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span>
|
||||||
|
<span class="lineno">169</span> <span class="n">dropout_prob</span><span class="o">=</span><span class="mf">0.0</span><span class="p">),</span>
|
||||||
|
<span class="lineno">170</span> <span class="n">dropout_prob</span><span class="o">=</span><span class="mf">0.0</span><span class="p">))</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-38'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-38'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Move to the device </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">173</span> <span class="k">return</span> <span class="n">m</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-39'>
|
||||||
|
<div class='docs doc-strings'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-39'>#</a>
|
||||||
|
</div>
|
||||||
|
<h4>Create and run the experiment</h4>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">176</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-40'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-40'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Create experiment </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"fta"</span><span class="p">,</span> <span class="n">writers</span><span class="o">=</span><span class="p">{</span><span class="s1">'screen'</span><span class="p">,</span> <span class="s1">'comet'</span><span class="p">,</span> <span class="s1">'labml'</span><span class="p">})</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-41'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-41'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Create configs </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">183</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-42'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-42'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Override configurations </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">185</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-43'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-43'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Use character level tokenizer </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">187</span> <span class="s1">'tokenizer'</span><span class="p">:</span> <span class="s1">'character'</span><span class="p">,</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-44'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-44'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Prompt separator is blank </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">189</span> <span class="s1">'prompt_separator'</span><span class="p">:</span> <span class="s1">''</span><span class="p">,</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-45'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-45'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Starting prompt for sampling </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">191</span> <span class="s1">'prompt'</span><span class="p">:</span> <span class="s1">'It is '</span><span class="p">,</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-46'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-46'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Use Tiny Shakespeare dataset </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">193</span> <span class="s1">'text'</span><span class="p">:</span> <span class="s1">'tiny_shakespeare'</span><span class="p">,</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-47'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-47'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Use a context size of <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">256</span></span></span></span> </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">196</span> <span class="s1">'seq_len'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-48'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-48'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Train for 32 epochs </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">198</span> <span class="s1">'epochs'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-49'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-49'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Batch size <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">16</span></span></span></span> </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">200</span> <span class="s1">'batch_size'</span><span class="p">:</span> <span class="mi">16</span><span class="p">,</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-50'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-50'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Switch between training and validation <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">10</span></span></span></span> times per epoch </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">202</span> <span class="s1">'inner_iterations'</span><span class="p">:</span> <span class="mi">10</span><span class="p">,</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-51'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-51'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Adam optimizer with no warmup </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">205</span> <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Adam'</span><span class="p">,</span>
|
||||||
|
<span class="lineno">206</span> <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">3e-4</span><span class="p">,</span>
|
||||||
|
<span class="lineno">207</span> <span class="p">})</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-52'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-52'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Set model(s) for saving and loading </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">210</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">'model'</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-53'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-53'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Start the experiment </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">213</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-54'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-54'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Run training </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">215</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='section' id='section-55'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-55'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Run the experiment when this module is executed directly </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">219</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||||
|
<span class="lineno">220</span> <span class="n">main</span><span class="p">()</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='footer'>
|
||||||
|
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||||
|
<a href="https://labml.ai">labml.ai</a>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<script src="../../interactive.js?v=1"></script>
|
||||||
|
<script>
|
||||||
|
function handleImages() {
|
||||||
|
var images = document.querySelectorAll('p>img')
|
||||||
|
|
||||||
|
for (var i = 0; i < images.length; ++i) {
|
||||||
|
handleImage(images[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleImage(img) {
|
||||||
|
img.parentElement.style.textAlign = 'center'
|
||||||
|
|
||||||
|
var modal = document.createElement('div')
|
||||||
|
modal.id = 'modal'
|
||||||
|
|
||||||
|
var modalContent = document.createElement('div')
|
||||||
|
modal.appendChild(modalContent)
|
||||||
|
|
||||||
|
var modalImage = document.createElement('img')
|
||||||
|
modalContent.appendChild(modalImage)
|
||||||
|
|
||||||
|
var span = document.createElement('span')
|
||||||
|
span.classList.add('close')
|
||||||
|
span.textContent = 'x'
|
||||||
|
modal.appendChild(span)
|
||||||
|
|
||||||
|
img.onclick = function () {
|
||||||
|
console.log('clicked')
|
||||||
|
document.body.appendChild(modal)
|
||||||
|
modalImage.src = img.src
|
||||||
|
}
|
||||||
|
|
||||||
|
span.onclick = function () {
|
||||||
|
document.body.removeChild(modal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
handleImages()
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
docs/activations/fta/index.html (new file, 402 lines)
@@ -0,0 +1,402 @@
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||||
|
<meta name="description" content="PyTorch implementation and tutorial of Fuzzy Tiling Activations from the paper Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online."/>
|
||||||
|
|
||||||
|
<meta name="twitter:card" content="summary"/>
|
||||||
|
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||||
|
<meta name="twitter:title" content="Fuzzy Tiling Activations"/>
|
||||||
|
<meta name="twitter:description" content="PyTorch implementation and tutorial of Fuzzy Tiling Activations from the paper Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online."/>
|
||||||
|
<meta name="twitter:site" content="@labmlai"/>
|
||||||
|
<meta name="twitter:creator" content="@labmlai"/>
|
||||||
|
|
||||||
|
<meta property="og:url" content="https://nn.labml.ai/activations/fta/index.html"/>
|
||||||
|
<meta property="og:title" content="Fuzzy Tiling Activations"/>
|
||||||
|
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||||
|
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||||
|
<meta property="og:type" content="object"/>
|
||||||
|
<meta property="og:title" content="Fuzzy Tiling Activations"/>
|
||||||
|
<meta property="og:description" content="PyTorch implementation and tutorial of Fuzzy Tiling Activations from the paper Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online."/>
|
||||||
|
|
||||||
|
<title>Fuzzy Tiling Activations</title>
|
||||||
|
<link rel="shortcut icon" href="/icon.png"/>
|
||||||
|
<link rel="stylesheet" href="../../pylit.css?v=1">
|
||||||
|
<link rel="canonical" href="https://nn.labml.ai/activations/fta/index.html"/>
|
||||||
|
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||||||
|
|
||||||
|
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||||
|
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||||
|
<script>
|
||||||
|
window.dataLayer = window.dataLayer || [];
|
||||||
|
|
||||||
|
function gtag() {
|
||||||
|
dataLayer.push(arguments);
|
||||||
|
}
|
||||||
|
|
||||||
|
gtag('js', new Date());
|
||||||
|
|
||||||
|
gtag('config', 'G-4V3HC8HBLH');
|
||||||
|
</script>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id='container'>
|
||||||
|
<div id="background"></div>
|
||||||
|
<div class='section'>
|
||||||
|
<div class='docs'>
|
||||||
|
<p>
|
||||||
|
<a class="parent" href="/">home</a>
|
||||||
|
<a class="parent" href="../index.html">activations</a>
|
||||||
|
<a class="parent" href="index.html">fta</a>
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
|
||||||
|
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/activations/fta/__init__.py">
|
||||||
|
<img alt="Github"
|
||||||
|
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||||
|
style="max-width:100%;"/></a>
|
||||||
|
<a href="https://twitter.com/labmlai"
|
||||||
|
rel="nofollow">
|
||||||
|
<img alt="Twitter"
|
||||||
|
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||||
|
style="max-width:100%;"/></a>
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
    <div class='section' id='section-0'>
        <div class='docs doc-strings'>
            <div class='section-link'><a href='#section-0'>#</a></div>
            <h1>Fuzzy Tiling Activations (FTA)</h1>
            <p>This is a <a href="https://pytorch.org">PyTorch</a> implementation/tutorial of <a href="https://papers.labml.ai/paper/aca66d8edc8911eba3db37f65e372566">Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online</a>.</p>
            <p>Fuzzy tiling activations are a form of sparse activations based on binning.</p>
            <p>Binning classifies a scalar value into a bin based on a set of intervals. One problem with binning is that it gives zero gradients for most values (everywhere except at the bin boundaries). The other is that it loses precision if the bin intervals are large.</p>
            <p>FTA overcomes these disadvantages. Instead of the hard boundaries of tiling activations, FTA uses soft boundaries between bins. This gives non-zero gradients for all or a wide range of values, and it loses less precision because the position of a value within a bin is captured by the partial activation values.</p>
            <h4>Tiling Activations</h4>
            <p>$\mathbf{c}$ is the tiling vector,</p>
            <p>$$\mathbf{c} = (l, l + \delta, l + 2\delta, \dots, u - 2\delta, u - \delta)$$</p>
            <p>where $[l, u]$ is the input range, $\delta$ is the bin size, and $u - l$ is divisible by $\delta$.</p>
            <p>The tiling activation is,</p>
            <p>$$\phi(z) = 1 - I_+\big(\max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0)\big)$$</p>
            <p>where $I_+(\cdot)$ is the indicator function, which gives $1$ if the input is positive and $0$ otherwise.</p>
            <p>Note that the tiling activation gives zero gradients almost everywhere because of its hard boundaries.</p>
            <h4>Fuzzy Tiling Activations</h4>
            <p>The fuzzy indicator function,</p>
            <p>$$I_{\eta,+}(x) = I_+(\eta - x) x + I_+(x - \eta)$$</p>
            <p>which increases linearly (from $0$ up to $\eta$) when $0 \le x < \eta$ and is equal to $1$ for $\eta \le x$. $\eta$ is a hyper-parameter.</p>
            <p>FTA uses this to create soft boundaries between bins.</p>
            <p>$$\phi_\eta(z) = 1 - I_{\eta,+}\big(\max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0)\big)$$</p>
            <p><a href="experiment.html">Here's a simple experiment</a> that uses FTA in a transformer.</p>
            <p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step"><img alt="Open In Comet" src="https://images.labml.ai/images/comet.svg?experiment=capsule_networks&file=model"></a></p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>import torch
from torch import nn</pre></div>
        </div>
    </div>
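    <div class='section'>
        <div class='docs'>
            <p>To make the formula above concrete, here is a minimal sketch that evaluates $\phi_\eta(z)$ directly for a single scalar input. The values of $l$, $u$, $\delta$, $\eta$ and $z$ below are assumptions chosen only for illustration; the point is that only the entries near the bin containing $z$ are non-zero, and the entry next to a soft boundary is fractional.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>import torch

# Illustrative values (assumptions for this sketch, not taken from the paper or the module below)
l, u, delta, eta = 0., 8., 2., 0.1
c = torch.arange(l, u, delta)  # tiling vector: tensor([0., 2., 4., 6.])
z = torch.tensor(3.95)         # an input close to the boundary between bins [2, 4) and [4, 6)

def fuzzy_i_plus(x, eta):
    # I_{eta,+}(x): passes x through for x <= eta, saturates at 1 for x > eta
    return (x <= eta) * x + (x > eta)

phi = 1. - fuzzy_i_plus(torch.clip(c - z, min=0.) + torch.clip(z - delta - c, min=0.), eta)
print(phi)  # tensor([0.0000, 1.0000, 0.9500, 0.0000]) -- sparse, with a soft boundary</pre></div>
        </div>
    </div>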
    <div class='section' id='section-1'>
        <div class='docs doc-strings'>
            <div class='section-link'><a href='#section-1'>#</a></div>
            <h3>Fuzzy Tiling Activations (FTA)</h3>
        </div>
        <div class='code'>
            <div class="highlight"><pre>class FTA(nn.Module):</pre></div>
        </div>
    </div>
    <div class='section' id='section-2'>
        <div class='docs doc-strings'>
            <div class='section-link'><a href='#section-2'>#</a></div>
            <ul>
                <li><code>lower_limit</code> is the lower limit $l$</li>
                <li><code>upper_limit</code> is the upper limit $u$</li>
                <li><code>delta</code> is the bin size $\delta$</li>
                <li><code>eta</code> is the parameter $\eta$ that determines the softness of the boundaries.</li>
            </ul>
        </div>
        <div class='code'>
            <div class="highlight"><pre>    def __init__(self, lower_limit: float, upper_limit: float, delta: float, eta: float):</pre></div>
        </div>
    </div>
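    <div class='section'>
        <div class='docs'>
            <p>Since FTA multiplies the size of the last dimension by its expansion factor, the layer that follows it has to be sized accordingly. The sketch below shows one way this could be wired into a feed-forward block; it is only an illustration, not the transformer experiment linked above, and the names <code>d_model</code> and <code>d_ff</code>, the chosen hyper-parameters, and the import path are assumptions.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>import torch
from torch import nn

from labml_nn.activations.fta import FTA  # assumed module path for the class defined on this page

d_model, d_ff = 512, 2048                       # assumed sizes for this sketch
fta = FTA(lower_limit=-1., upper_limit=1., delta=0.25, eta=0.05)

ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    fta,                                        # expands the last dimension by (u - l) / delta = 8
    nn.Linear(d_ff * fta.expansion_factor, d_model),
)

x = torch.randn(4, d_model)
print(ffn(x).shape)                             # torch.Size([4, 512])</pre></div>
        </div>
    </div>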
    <div class='section' id='section-3'>
        <div class='docs'>
            <div class='section-link'><a href='#section-3'>#</a></div>
        </div>
        <div class='code'>
            <div class="highlight"><pre>        super().__init__()</pre></div>
        </div>
    </div>
    <div class='section' id='section-4'>
        <div class='docs'>
            <div class='section-link'><a href='#section-4'>#</a></div>
            <p>Initialize tiling vector $$\mathbf{c} = (l, l + \delta, l + 2\delta, \dots, u - 2\delta, u - \delta)$$</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>        self.c = nn.Parameter(torch.arange(lower_limit, upper_limit, delta), requires_grad=False)</pre></div>
        </div>
    </div>
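    <div class='section'>
        <div class='docs'>
            <p>For example, with the values used in the test at the end of this page ($l = -10$, $u = 10$, $\delta = 2$), the tiling vector has $10$ entries:</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>import torch

torch.arange(-10., 10., 2.)
# tensor([-10., -8., -6., -4., -2., 0., 2., 4., 6., 8.])</pre></div>
        </div>
    </div>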
    <div class='section' id='section-5'>
        <div class='docs'>
            <div class='section-link'><a href='#section-5'>#</a></div>
            <p>The input vector expands by a factor equal to the number of bins $\frac{u - l}{\delta}$</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>        self.expansion_factor = len(self.c)</pre></div>
        </div>
    </div>
    <div class='section' id='section-6'>
        <div class='docs'>
            <div class='section-link'><a href='#section-6'>#</a></div>
            <p>$\delta$</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>        self.delta = delta</pre></div>
        </div>
    </div>
    <div class='section' id='section-7'>
        <div class='docs'>
            <div class='section-link'><a href='#section-7'>#</a></div>
            <p>$\eta$</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>        self.eta = eta</pre></div>
        </div>
    </div>
    <div class='section' id='section-8'>
        <div class='docs doc-strings'>
            <div class='section-link'><a href='#section-8'>#</a></div>
            <h4>Fuzzy indicator function</h4>
            <p>$$I_{\eta,+}(x) = I_+(\eta - x) x + I_+(x - \eta)$$</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>    def fuzzy_i_plus(self, x: torch.Tensor):</pre></div>
        </div>
    </div>
    <div class='section' id='section-9'>
        <div class='docs'>
            <div class='section-link'><a href='#section-9'>#</a></div>
        </div>
        <div class='code'>
            <div class="highlight"><pre>        return (x <= self.eta) * x + (x > self.eta)</pre></div>
        </div>
    </div>
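    <div class='section'>
        <div class='docs'>
            <p>As a quick sanity check of this expression (the input values are illustrative, with $\eta = 0.5$ as in the test below): inputs at or below $\eta$ pass through unchanged, while inputs above $\eta$ saturate at $1$.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>import torch

eta = 0.5
x = torch.tensor([0.0, 0.2, 0.5, 0.7, 2.0])
print((x <= eta) * x + (x > eta))
# tensor([0.0000, 0.2000, 0.5000, 1.0000, 1.0000])</pre></div>
        </div>
    </div>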
    <div class='section' id='section-10'>
        <div class='docs'>
            <div class='section-link'><a href='#section-10'>#</a></div>
        </div>
        <div class='code'>
            <div class="highlight"><pre>    def forward(self, z: torch.Tensor):</pre></div>
        </div>
    </div>
    <div class='section' id='section-11'>
        <div class='docs'>
            <div class='section-link'><a href='#section-11'>#</a></div>
            <p>Add another dimension of size $1$. We will expand this into bins.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>        z = z.view(*z.shape, 1)</pre></div>
        </div>
    </div>
    <div class='section' id='section-12'>
        <div class='docs'>
            <div class='section-link'><a href='#section-12'>#</a></div>
            <p>$$\phi_\eta(z) = 1 - I_{\eta,+}\big(\max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0)\big)$$</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>        z = 1. - self.fuzzy_i_plus(torch.clip(self.c - z, min=0.) + torch.clip(z - self.delta - self.c, min=0.))</pre></div>
        </div>
    </div>
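    <div class='section'>
        <div class='docs'>
            <p>This is where FTA differs from hard tiling in practice: near a soft boundary the activation depends linearly on $z$, so gradients can flow back to the input. A minimal check of that claim (again only a sketch; the import path and the chosen input are assumptions):</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>import torch

from labml_nn.activations.fta import FTA  # assumed module path for the class defined on this page

fta = FTA(lower_limit=-10., upper_limit=10., delta=2., eta=0.5)
z = torch.tensor([3.9], requires_grad=True)  # within eta of the boundary between bins [2, 4) and [4, 6)
fta(z).sum().backward()
print(z.grad)  # tensor([1.]) -- non-zero; a hard tiling activation would give zero gradient here</pre></div>
        </div>
    </div>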
    <div class='section' id='section-13'>
        <div class='docs'>
            <div class='section-link'><a href='#section-13'>#</a></div>
            <p>Reshape back to the original number of dimensions. The last dimension size gets expanded by the number of bins, $\frac{u - l}{\delta}$.</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>        return z.view(*z.shape[:-2], -1)</pre></div>
        </div>
    </div>
    <div class='section' id='section-14'>
        <div class='docs doc-strings'>
            <div class='section-link'><a href='#section-14'>#</a></div>
            <h4>Code to test the FTA module</h4>
        </div>
        <div class='code'>
            <div class="highlight"><pre>def _test():</pre></div>
        </div>
    </div>
    <div class='section' id='section-15'>
        <div class='docs'>
            <div class='section-link'><a href='#section-15'>#</a></div>
        </div>
        <div class='code'>
            <div class="highlight"><pre>    from labml.logger import inspect</pre></div>
        </div>
    </div>
    <div class='section' id='section-16'>
        <div class='docs'>
            <div class='section-link'><a href='#section-16'>#</a></div>
            <p>Initialize</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>    a = FTA(-10, 10, 2., 0.5)</pre></div>
        </div>
    </div>
    <div class='section' id='section-17'>
        <div class='docs'>
            <div class='section-link'><a href='#section-17'>#</a></div>
            <p>Print $\mathbf{c}$</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>    inspect(a.c)</pre></div>
        </div>
    </div>
    <div class='section' id='section-18'>
        <div class='docs'>
            <div class='section-link'><a href='#section-18'>#</a></div>
            <p>Print number of bins $\frac{u - l}{\delta}$</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>    inspect(a.expansion_factor)</pre></div>
        </div>
    </div>
    <div class='section' id='section-19'>
        <div class='docs'>
            <div class='section-link'><a href='#section-19'>#</a></div>
            <p>Input $z$</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>    z = torch.tensor([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9., 10., 11.])</pre></div>
        </div>
    </div>
    <div class='section' id='section-20'>
        <div class='docs'>
            <div class='section-link'><a href='#section-20'>#</a></div>
            <p>Print $z$</p>
        </div>
        <div class='code'>
            <div class="highlight"><pre>    inspect(z)</pre></div>
        </div>
    </div>
<div class='section' id='section-21'>
|
||||||
|
<div class='docs'>
|
||||||
|
<div class='section-link'>
|
||||||
|
<a href='#section-21'>#</a>
|
||||||
|
</div>
|
||||||
|
<p>Print <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord mathnormal" style="">ϕ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqn" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">η</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mopen" style="">(</span><span class="mord coloredeq eqs" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span><span class="mclose" style="">)</span></span></span></span></span> </p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
<div class='code'>
|
||||||
|
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">inspect</span><span class="p">(</span><span class="n">a</span><span class="p">(</span><span class="n">z</span><span class="p">))</span>
|
||||||
|
<span class="lineno">129</span>
|
||||||
|
<span class="lineno">130</span>
|
||||||
|
<span class="lineno">131</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||||
|
<span class="lineno">132</span> <span class="n">_test</span><span class="p">()</span></pre></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class='footer'>
|
||||||
|
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||||
|
<a href="https://labml.ai">labml.ai</a>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<script src="../../interactive.js?v=1"></script>
|
||||||
|
<script>
|
||||||
|
function handleImages() {
|
||||||
|
var images = document.querySelectorAll('p>img')
|
||||||
|
|
||||||
|
for (var i = 0; i < images.length; ++i) {
|
||||||
|
handleImage(images[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleImage(img) {
|
||||||
|
img.parentElement.style.textAlign = 'center'
|
||||||
|
|
||||||
|
var modal = document.createElement('div')
|
||||||
|
modal.id = 'modal'
|
||||||
|
|
||||||
|
var modalContent = document.createElement('div')
|
||||||
|
modal.appendChild(modalContent)
|
||||||
|
|
||||||
|
var modalImage = document.createElement('img')
|
||||||
|
modalContent.appendChild(modalImage)
|
||||||
|
|
||||||
|
var span = document.createElement('span')
|
||||||
|
span.classList.add('close')
|
||||||
|
span.textContent = 'x'
|
||||||
|
modal.appendChild(span)
|
||||||
|
|
||||||
|
img.onclick = function () {
|
||||||
|
console.log('clicked')
|
||||||
|
document.body.appendChild(modal)
|
||||||
|
modalImage.src = img.src
|
||||||
|
}
|
||||||
|
|
||||||
|
span.onclick = function () {
|
||||||
|
document.body.removeChild(modal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
handleImages()
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
@ -3,24 +3,24 @@
|
|||||||
<head>
|
<head>
|
||||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||||
<meta name="description" content=""/>
|
<meta name="description" content="A set of PyTorch implementations/tutorials related to neural network activations"/>
|
||||||
|
|
||||||
<meta name="twitter:card" content="summary"/>
|
<meta name="twitter:card" content="summary"/>
|
||||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||||
<meta name="twitter:title" content="__init__.py"/>
|
<meta name="twitter:title" content="Neural Network Activation Functions"/>
|
||||||
<meta name="twitter:description" content=""/>
|
<meta name="twitter:description" content="A set of PyTorch implementations/tutorials related to neural network activations"/>
|
||||||
<meta name="twitter:site" content="@labmlai"/>
|
<meta name="twitter:site" content="@labmlai"/>
|
||||||
<meta name="twitter:creator" content="@labmlai"/>
|
<meta name="twitter:creator" content="@labmlai"/>
|
||||||
|
|
||||||
<meta property="og:url" content="https://nn.labml.ai/activations/index.html"/>
|
<meta property="og:url" content="https://nn.labml.ai/activations/index.html"/>
|
||||||
<meta property="og:title" content="__init__.py"/>
|
<meta property="og:title" content="Neural Network Activation Functions"/>
|
||||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||||
<meta property="og:type" content="object"/>
|
<meta property="og:type" content="object"/>
|
||||||
<meta property="og:title" content="__init__.py"/>
|
<meta property="og:title" content="Neural Network Activation Functions"/>
|
||||||
<meta property="og:description" content=""/>
|
<meta property="og:description" content="A set of PyTorch implementations/tutorials related to neural network activations"/>
|
||||||
|
|
||||||
<title>__init__.py</title>
|
<title>Neural Network Activation Functions</title>
|
||||||
<link rel="shortcut icon" href="/icon.png"/>
|
<link rel="shortcut icon" href="/icon.png"/>
|
||||||
<link rel="stylesheet" href="../pylit.css?v=1">
|
<link rel="stylesheet" href="../pylit.css?v=1">
|
||||||
<link rel="canonical" href="https://nn.labml.ai/activations/index.html"/>
|
<link rel="canonical" href="https://nn.labml.ai/activations/index.html"/>
|
||||||
@ -64,14 +64,17 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-0'>
|
<div class='section' id='section-0'>
|
||||||
<div class='docs'>
|
<div class='docs doc-strings'>
|
||||||
<div class='section-link'>
|
<div class='section-link'>
|
||||||
<a href='#section-0'>#</a>
|
<a href='#section-0'>#</a>
|
||||||
</div>
|
</div>
|
||||||
|
<h1>Neural Network Activations</h1>
|
||||||
|
<ul><li><a href="fta/index.html">Fuzzy Tiling Activations</a> </li>
|
||||||
|
<li>🚧 <a href="swish/index.html">Swish</a></li></ul>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">1</span><span></span><span class="kn">from</span> <span class="nn">.swish</span> <span class="kn">import</span> <span class="n">Swish</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">from</span> <span class="nn">.swish</span> <span class="kn">import</span> <span class="n">Swish</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='footer'>
|
<div class='footer'>
|
||||||
|
@ -139,6 +139,8 @@
|
|||||||
<ul><li><a href="adaptive_computation/ponder_net/index.html">PonderNet</a></li></ul>
|
<ul><li><a href="adaptive_computation/ponder_net/index.html">PonderNet</a></li></ul>
|
||||||
<h4>✨ <a href="uncertainty/index.html">Uncertainty</a></h4>
|
<h4>✨ <a href="uncertainty/index.html">Uncertainty</a></h4>
|
||||||
<ul><li><a href="uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li></ul>
|
<ul><li><a href="uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li></ul>
|
||||||
|
<h4>✨ <a href="activations/index.html">Activations</a></h4>
|
||||||
|
<ul><li><a href="activations/fta/index.html">Fuzzy Tiling Activations</a></li></ul>
|
||||||
<h2>Highlighted Research Paper PDFs</h2>
|
<h2>Highlighted Research Paper PDFs</h2>
|
||||||
<ul><li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf">Autoregressive Search Engines: Generating Substrings as Document Identifiers</a> </li>
|
<ul><li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf">Autoregressive Search Engines: Generating Substrings as Document Identifiers</a> </li>
|
||||||
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf">Training Compute-Optimal Large Language Models</a> </li>
|
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf">Training Compute-Optimal Large Language Models</a> </li>
|
||||||
|
@ -69,9 +69,8 @@
|
|||||||
<div class='section-link'>
|
<div class='section-link'>
|
||||||
<a href='#section-0'>#</a>
|
<a href='#section-0'>#</a>
|
||||||
</div>
|
</div>
|
||||||
<h1>DeepNorm Experiment</h1>
|
<h1><a href="index.html">DeepNorm</a> Experiment</h1>
|
||||||
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/deep_norm/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://app.labml.ai/run/ec8e4dacb7f311ec8d1cd37d50b05c3d"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a> <a href="https://www.comet.ml/labml/deep-norm/61d817f80ff143c8825fba4aacd431d4?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&xAxis=step"><img alt="Open In Comet" src="https://images.labml.ai/images/comet.svg?experiment=deep_norm&file=experiment"></a></p>
|
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/deep_norm/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://app.labml.ai/run/ec8e4dacb7f311ec8d1cd37d50b05c3d"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">import</span> <span class="nn">copy</span>
|
<div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">import</span> <span class="nn">copy</span>
|
||||||
|
@ -85,7 +85,21 @@
|
|||||||
|
|
||||||
<url>
|
<url>
|
||||||
<loc>https://nn.labml.ai/activations/index.html</loc>
|
<loc>https://nn.labml.ai/activations/index.html</loc>
|
||||||
<lastmod>2021-01-25T16:30:00+00:00</lastmod>
|
<lastmod>2022-05-23T16:30:00+00:00</lastmod>
|
||||||
|
<priority>1.00</priority>
|
||||||
|
</url>
|
||||||
|
|
||||||
|
|
||||||
|
<url>
|
||||||
|
<loc>https://nn.labml.ai/activations/fta/index.html</loc>
|
||||||
|
<lastmod>2022-05-23T16:30:00+00:00</lastmod>
|
||||||
|
<priority>1.00</priority>
|
||||||
|
</url>
|
||||||
|
|
||||||
|
|
||||||
|
<url>
|
||||||
|
<loc>https://nn.labml.ai/activations/fta/experiment.html</loc>
|
||||||
|
<lastmod>2022-05-23T16:30:00+00:00</lastmod>
|
||||||
<priority>1.00</priority>
|
<priority>1.00</priority>
|
||||||
</url>
|
</url>
|
||||||
|
|
||||||
@ -197,7 +211,7 @@
|
|||||||
|
|
||||||
<url>
|
<url>
|
||||||
<loc>https://nn.labml.ai/normalization/deep_norm/experiment.html</loc>
|
<loc>https://nn.labml.ai/normalization/deep_norm/experiment.html</loc>
|
||||||
<lastmod>2022-04-23T16:30:00+00:00</lastmod>
|
<lastmod>2022-05-23T16:30:00+00:00</lastmod>
|
||||||
<priority>1.00</priority>
|
<priority>1.00</priority>
|
||||||
</url>
|
</url>
|
||||||
|
|
||||||
@ -295,7 +309,7 @@
|
|||||||
|
|
||||||
<url>
|
<url>
|
||||||
<loc>https://nn.labml.ai/index.html</loc>
|
<loc>https://nn.labml.ai/index.html</loc>
|
||||||
<lastmod>2022-05-03T16:30:00+00:00</lastmod>
|
<lastmod>2022-05-23T16:30:00+00:00</lastmod>
|
||||||
<priority>1.00</priority>
|
<priority>1.00</priority>
|
||||||
</url>
|
</url>
|
||||||
|
|
||||||
@ -589,7 +603,7 @@
|
|||||||
|
|
||||||
<url>
|
<url>
|
||||||
<loc>https://nn.labml.ai/transformers/rope/index.html</loc>
|
<loc>https://nn.labml.ai/transformers/rope/index.html</loc>
|
||||||
<lastmod>2022-02-23T16:30:00+00:00</lastmod>
|
<lastmod>2022-04-05T16:30:00+00:00</lastmod>
|
||||||
<priority>1.00</priority>
|
<priority>1.00</priority>
|
||||||
</url>
|
</url>
|
||||||
|
|
||||||
|
File diff suppressed because one or more lines are too long
@ -112,6 +112,10 @@ Solving games with incomplete information such as poker with CFR.
|
|||||||
|
|
||||||
* [Evidential Deep Learning to Quantify Classification Uncertainty](uncertainty/evidence/index.html)
|
* [Evidential Deep Learning to Quantify Classification Uncertainty](uncertainty/evidence/index.html)
|
||||||
|
|
||||||
|
#### ✨ [Activations](activations/index.html)
|
||||||
|
|
||||||
|
* [Fuzzy Tiling Activations](activations/fta/index.html)
|
||||||
|
|
||||||
## Highlighted Research Paper PDFs
|
## Highlighted Research Paper PDFs
|
||||||
|
|
||||||
* [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf)
|
* [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf)
|
||||||
|
@ -1 +1,14 @@
|
|||||||
|
"""
|
||||||
|
---
|
||||||
|
title: Neural Network Activation Functions
|
||||||
|
summary: >
|
||||||
|
A set of PyTorch implementations/tutorials related to neural network activations
|
||||||
|
---
|
||||||
|
|
||||||
|
# Neural Network Activations
|
||||||
|
|
||||||
|
* [Fuzzy Tiling Activations](fta/index.html)
|
||||||
|
* 🚧 [Swish](swish/index.html)
|
||||||
|
"""
|
||||||
|
|
||||||
from .swish import Swish
|
from .swish import Swish
|
||||||
|
132
labml_nn/activations/fta/__init__.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
"""
|
||||||
|
---
|
||||||
|
title: Fuzzy Tiling Activations
|
||||||
|
summary: >
|
||||||
|
PyTorch implementation and tutorial of Fuzzy Tiling Activations from the
|
||||||
|
paper Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online.
|
||||||
|
---
|
||||||
|
|
||||||
|
# Fuzzy Tiling Activations (FTA)
|
||||||
|
|
||||||
|
This is a [PyTorch](https://pytorch.org) implementation/tutorial of
|
||||||
|
[Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online](https://papers.labml.ai/paper/aca66d8edc8911eba3db37f65e372566).
|
||||||
|
|
||||||
|
Fuzzy tiling activations are a form of sparse activations based on binning.
|
||||||
|
|
||||||
|
Binning is classification of a scalar value into a bin based on intervals.
|
||||||
|
One problem with binning is that it gives zero gradients for most values (except at the boundary of bins).
|
||||||
|
The other is that binning loses precision if the bin intervals are large.
|
||||||
|
|
||||||
|
FTA overcomes these disadvantages.
|
||||||
|
Instead of hard boundaries like in Tiling Activations, FTA uses soft boundaries
|
||||||
|
between bins.
|
||||||
|
This gives non-zero gradients for all or a wide range of values.
|
||||||
|
It also doesn't lose precision, since the position of a value within a bin is captured by the partial (fractional) activation values.
|
||||||
|
|
||||||
|
#### Tiling Activations
|
||||||
|
|
||||||
|
$\mathbf{c}$ is the tiling vector,
|
||||||
|
|
||||||
|
$$\mathbf{c} = (l, l + \delta, l + 2 \delta, \dots, u - 2 \delta, u - \delta)$$
|
||||||
|
|
||||||
|
where $[l, u]$ is the input range, $\delta$ is the bin size, and $u - l$ is divisible by $\delta$.
|
||||||
|
|
||||||
|
Tiling activation is,
|
||||||
|
|
||||||
|
$$\phi(z) = 1 - I_+ \big( \max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0) \big)$$
|
||||||
|
|
||||||
|
where $I_+(\cdot)$ is the indicator function which gives $1$ if the input is positive and $0$ otherwise.
|
||||||
|
|
||||||
|
Note that tiling activation gives zero gradients because it has hard boundaries.
|
||||||
|
|
||||||
|
#### Fuzzy Tiling Activations
|
||||||
|
|
||||||
|
The fuzzy indicator function,
|
||||||
|
|
||||||
|
$$I_{\eta,+}(x) = I_+(\eta - x) x + I_+ (x - \eta)$$
|
||||||
|
|
||||||
|
which is equal to $x$ (increasing linearly from $0$ to $\eta$) when $0 \le x \lt \eta$
|
||||||
|
and is equal to $1$ for $\eta \le x$.
|
||||||
|
$\eta$ is a hyper-parameter.
|
||||||
|
|
||||||
|
FTA uses this to create soft boundaries between bins.
|
||||||
|
|
||||||
|
$$\phi_\eta(z) = 1 - I_{\eta,+} \big( \max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0) \big)$$
|
||||||
|
|
||||||
|
[Here's a simple experiment](experiment.html) that uses FTA in a transformer.
|
||||||
|
|
||||||
|
[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb)
|
||||||
|
[](https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class FTA(nn.Module):
|
||||||
|
"""
|
||||||
|
### Fuzzy Tiling Activations (FTA)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, lower_limit: float, upper_limit: float, delta: float, eta: float):
|
||||||
|
"""
|
||||||
|
:param lower_limit: is the lower limit $l$
|
||||||
|
:param upper_limit: is the upper limit $u$
|
||||||
|
:param delta: is the bin size $\delta$
|
||||||
|
:param eta: is the parameter $\eta$ that determines the softness of the boundaries.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
# Initialize tiling vector
|
||||||
|
# $$\mathbf{c} = (l, l + \delta, l + 2 \delta, \dots, u - 2 \delta, u - \delta)$$
|
||||||
|
self.c = nn.Parameter(torch.arange(lower_limit, upper_limit, delta), requires_grad=False)
|
||||||
|
# The input vector expands by a factor equal to the number of bins $\frac{u - l}{\delta}$
|
||||||
|
self.expansion_factor = len(self.c)
|
||||||
|
# $\delta$
|
||||||
|
self.delta = delta
|
||||||
|
# $\eta$
|
||||||
|
self.eta = eta
|
||||||
|
|
||||||
|
def fuzzy_i_plus(self, x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
#### Fuzzy indicator function
|
||||||
|
|
||||||
|
$$I_{\eta,+}(x) = I_+(\eta - x) x + I_+ (x - \eta)$$
|
||||||
|
"""
|
||||||
|
return (x <= self.eta) * x + (x > self.eta)
|
||||||
|
|
||||||
|
def forward(self, z: torch.Tensor):
|
||||||
|
# Add another dimension of size $1$.
|
||||||
|
# We will expand this into bins.
|
||||||
|
z = z.view(*z.shape, 1)
|
||||||
|
|
||||||
|
# $$\phi_\eta(z) = 1 - I_{\eta,+} \big( \max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0) \big)$$
|
||||||
|
z = 1. - self.fuzzy_i_plus(torch.clip(self.c - z, min=0.) + torch.clip(z - self.delta - self.c, min=0.))
|
||||||
|
|
||||||
|
# Reshape back to original number of dimensions.
|
||||||
|
# The last dimension size gets expanded by the number of bins, $\frac{u - l}{\delta}$.
|
||||||
|
return z.view(*z.shape[:-2], -1)
|
||||||
|
|
||||||
|
|
||||||
|
def _test():
|
||||||
|
"""
|
||||||
|
#### Code to test the FTA module
|
||||||
|
"""
|
||||||
|
from labml.logger import inspect
|
||||||
|
|
||||||
|
# Initialize
|
||||||
|
a = FTA(-10, 10, 2., 0.5)
|
||||||
|
# Print $\mathbf{c}$
|
||||||
|
inspect(a.c)
|
||||||
|
# Print number of bins $\frac{u - l}{\delta}$
|
||||||
|
inspect(a.expansion_factor)
|
||||||
|
|
||||||
|
# Input $z$
|
||||||
|
z = torch.tensor([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9., 10., 11.])
|
||||||
|
# Print $z$
|
||||||
|
inspect(z)
|
||||||
|
# Print $\phi_\eta(z)$
|
||||||
|
inspect(a(z))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
_test()
|
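To make the binning concrete, here is a small worked example (an illustrative sketch, not part of this commit; it simply re-applies the formulas above with the same hyper-parameters as `_test`, i.e. $l=-10$, $u=10$, $\delta=2$, $\eta=0.5$):

```python
import torch

# Tiling vector c = (-10, -8, ..., 6, 8): ten bins of width delta = 2
c = torch.arange(-10., 10., 2.)
delta, eta = 2., 0.5

# A single input near the upper edge of the bin [2, 4)
z = torch.tensor(3.6)

# Distance of z from each bin: max(c - z, 0) + max(z - delta - c, 0)
d = torch.clip(c - z, min=0.) + torch.clip(z - delta - c, min=0.)

# Fuzzy indicator I_{eta,+}(d), then the activation phi_eta(z) = 1 - I_{eta,+}(d)
phi = 1. - ((d <= eta) * d + (d > eta))

print(phi)
# The bin [2, 4) containing z is fully active (1.0); because z is within eta of
# the boundary at 4, the neighbouring bin [4, 6) gets the partial value 0.6,
# and every other bin is exactly 0.
```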
299
labml_nn/activations/fta/experiment.ipynb
Normal file
@ -0,0 +1,299 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "AYV_dMVDxyc2",
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"[](https://github.com/labmlai/annotated_deep_learning_paper_implementations)\n",
|
||||||
|
"[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb)\n",
|
||||||
|
"[](https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step)\n",
|
||||||
|
"\n",
|
||||||
|
"## [Fuzzy Tiling Activations](https://nn.labml.ai/activations/fta/index.html)\n",
|
||||||
|
"\n",
|
||||||
|
"Here we train a transformer that uses [Fuzzy Tiling Activation](https://nn.labml.ai/activations/fta/index.html) in the\n",
|
||||||
|
"[Feed-Forward Network](https://nn.labml.ai/transformers/feed_forward.html).\n",
|
||||||
|
"We use it for a language model and train it on Tiny Shakespeare dataset\n",
|
||||||
|
"for demonstration.\n",
|
||||||
|
"However, this is probably not the ideal task for FTA, and we\n",
|
||||||
|
"believe FTA is more suitable for modeling data with continuous variables."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "AahG_i2y5tY9",
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"### Install the packages"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "ZCzmCrAIVg0L",
|
||||||
|
"outputId": "cf107fb2-4d50-4c67-af34-367624553421",
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!pip install labml-nn comet_ml --quiet"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"### Enable [Comet](https://www.comet.ml)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"#@markdown Select in order to enable logging this experiment to [Comet](https://www.comet.ml).\n",
|
||||||
|
"use_comet = False #@param {type:\"boolean\"}\n",
|
||||||
|
"\n",
|
||||||
|
"if use_comet:\n",
|
||||||
|
" import comet_ml\n",
|
||||||
|
" comet_ml.init(project_name='fta')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "SE2VUQ6L5zxI",
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"### Imports"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"\n",
|
||||||
|
"from labml import experiment\n",
|
||||||
|
"from labml.configs import option\n",
|
||||||
|
"from labml_nn.activations.fta.experiment import Configs"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Create an experiment"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"experiment.create(name=\"fta\", writers={\"screen\", \"comet\"} if use_comet else {'screen'})"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Configurations"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"conf = Configs()"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"Set experiment configurations and assign a configurations dictionary to override configurations"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"experiment.configs(conf, {\n",
|
||||||
|
" 'tokenizer': 'character',\n",
|
||||||
|
" 'prompt_separator': '',\n",
|
||||||
|
" 'prompt': 'It is ',\n",
|
||||||
|
" 'text': 'tiny_shakespeare',\n",
|
||||||
|
"\n",
|
||||||
|
" 'seq_len': 256,\n",
|
||||||
|
" 'epochs': 32,\n",
|
||||||
|
" 'batch_size': 16,\n",
|
||||||
|
" 'inner_iterations': 10,\n",
|
||||||
|
"\n",
|
||||||
|
" 'optimizer.optimizer': 'Adam',\n",
|
||||||
|
" 'optimizer.learning_rate': 3e-4,\n",
|
||||||
|
"})"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "EvI7MtgJ61w5",
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"Set PyTorch models for loading and saving"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/",
|
||||||
|
"height": 255
|
||||||
|
},
|
||||||
|
"id": "GDlt7dp-5ALt",
|
||||||
|
"outputId": "e7548e8f-c541-4618-dc5a-1597cae42003",
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"experiment.add_pytorch_models({'model': conf.model})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "KJZRf8527GxL",
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"### Start the experiment and run the training loop."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/",
|
||||||
|
"height": 1000
|
||||||
|
},
|
||||||
|
"id": "aIAWo7Fw5DR8",
|
||||||
|
"outputId": "db979785-bfe3-4eda-d3eb-8ccbe61053e5",
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Start the experiment\n",
|
||||||
|
"with experiment.start():\n",
|
||||||
|
" conf.run()"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"accelerator": "GPU",
|
||||||
|
"colab": {
|
||||||
|
"collapsed_sections": [],
|
||||||
|
"name": "FTA",
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.7.11"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
220
labml_nn/activations/fta/experiment.py
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
"""
|
||||||
|
---
|
||||||
|
title: Fuzzy Tiling Activation Experiment
|
||||||
|
summary: >
|
||||||
|
Training a transformer with FTA in FFN on Tiny Shakespeare.
|
||||||
|
---
|
||||||
|
|
||||||
|
# [Fuzzy Tiling Activation](index.html) Experiment
|
||||||
|
|
||||||
|
Here we train a transformer that uses [Fuzzy Tiling Activation](index.html) in the
|
||||||
|
[Feed-Forward Network](../../transformers/feed_forward.html).
|
||||||
|
We use it for a language model and train it on the Tiny Shakespeare dataset
|
||||||
|
for demonstration.
|
||||||
|
|
||||||
|
However, this is probably not the ideal task for FTA, and we
|
||||||
|
believe FTA is more suitable for modeling data with continuous variables.
|
||||||
|
|
||||||
|
[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb)
|
||||||
|
[](https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from labml import experiment
|
||||||
|
from labml.configs import option
|
||||||
|
from labml_helpers.module import Module
|
||||||
|
from labml_nn.activations.fta import FTA
|
||||||
|
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
|
||||||
|
from labml_nn.transformers import MultiHeadAttention, TransformerLayer
|
||||||
|
from labml_nn.transformers.utils import subsequent_mask
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForwardFTA(nn.Module):
|
||||||
|
"""
|
||||||
|
## FFN module with [FTA](index.html) activation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, d_model: int, d_ff: int,
|
||||||
|
activation: FTA,
|
||||||
|
dropout: float = 0.1):
|
||||||
|
"""
|
||||||
|
* `d_model` is the number of features in a token embedding
|
||||||
|
* `d_ff` is the number of features in the hidden layer of the FFN
|
||||||
|
* `activation` is the FTA activation module
|
||||||
|
* `dropout` is dropout probability for the hidden layer
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
# Layer one parameterized by weight $W_1$ and bias $b_1$
|
||||||
|
self.layer1 = nn.Linear(d_model, d_ff)
|
||||||
|
# Layer two parameterized by weight $W_2$ and bias $b_2$
|
||||||
|
self.layer2 = nn.Linear(d_ff * activation.expansion_factor, d_model)
|
||||||
|
# Hidden layer dropout
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
# Activation function $f$
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
# $f(x W_1 + b_1)$
|
||||||
|
x = self.activation(self.layer1(x))
|
||||||
|
# Apply dropout
|
||||||
|
x = self.dropout(x)
|
||||||
|
#
|
||||||
|
return self.layer2(x)
|
||||||
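A quick shape check for the `FeedForwardFTA` block above (a hedged sketch, not part of this diff; it assumes this version of `labml-nn` is installed so that `FTA` and `FeedForwardFTA` resolve to the classes defined in this change):

```python
import torch

from labml_nn.activations.fta import FTA
from labml_nn.activations.fta.experiment import FeedForwardFTA

# Same defaults as the Configs class below: d_model = 256, d_ff = 256,
# FTA over [-1, 1) with delta = 0.2 and eta = 0.05, i.e. (u - l) / delta = 10 bins
fta = FTA(-1., 1., 0.2, 0.05)
ffn = FeedForwardFTA(d_model=256, d_ff=256, activation=fta, dropout=0.1)

x = torch.randn(8, 4, 256)  # [seq_len, batch_size, d_model]

# Inside the block, layer1 maps 256 -> 256, FTA expands that to
# 256 * fta.expansion_factor features, and layer2 projects back to 256.
y = ffn(x)
assert y.shape == x.shape
```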
|
|
||||||
|
|
||||||
|
class AutoregressiveTransformer(Module):
|
||||||
|
"""
|
||||||
|
## Auto-Regressive model
|
||||||
|
|
||||||
|
This is an autoregressive transformer model that uses Feed-Forward Networks with
|
||||||
|
[Fuzzy Tiling Activations](index.html).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n_tokens: int, d_model: int, n_layers: int, layer: TransformerLayer):
|
||||||
|
"""
|
||||||
|
:param n_tokens: is the number of tokens in the vocabulary
|
||||||
|
:param d_model: is the embedding size
|
||||||
|
:param n_layers: is the number of transformer layers
|
||||||
|
:param layer: is the layer. We use `n_layers` copies of this for the transformer.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
# Transformer with `n_layers` layers
|
||||||
|
self.transformer_layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
|
||||||
|
|
||||||
|
# Token embedding layer
|
||||||
|
self.emb = nn.Embedding(n_tokens, d_model)
|
||||||
|
# Readout layer
|
||||||
|
self.readout = nn.Linear(d_model, n_tokens)
|
||||||
|
|
||||||
|
# The mask will be initialized on the first call
|
||||||
|
self.mask = None
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
:param x: are the input tokens of shape `[seq_len, batch_size]`
|
||||||
|
"""
|
||||||
|
# Create auto-regressive mask
|
||||||
|
if self.mask is None or self.mask.size(0) != len(x):
|
||||||
|
# Subsequent mask, masks out tokens from attending to future tokens
|
||||||
|
self.mask = subsequent_mask(len(x)).to(x.device)
|
||||||
|
|
||||||
|
# Get the token embeddings
|
||||||
|
x = self.emb(x)
|
||||||
|
# Transformer encoder
|
||||||
|
for layer in self.transformer_layers:
|
||||||
|
x = layer(x=x, mask=self.mask)
|
||||||
|
# Get logits
|
||||||
|
x = self.readout(x)
|
||||||
|
|
||||||
|
# Return results
|
||||||
|
return x, None
|
||||||
|
|
||||||
|
|
||||||
|
class Configs(NLPAutoRegressionConfigs):
|
||||||
|
"""
|
||||||
|
## Configurations
|
||||||
|
|
||||||
|
This inherits from
|
||||||
|
[`NLPAutoRegressionConfigs`](../../experiments/nlp_autoregression.html#NLPAutoRegressionConfigs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Model
|
||||||
|
model: AutoregressiveTransformer
|
||||||
|
|
||||||
|
# Number of layers
|
||||||
|
n_layers: int = 4
|
||||||
|
|
||||||
|
# $\alpha$ and $\beta$ for DeepNorm
|
||||||
|
deep_norm_alpha: float
|
||||||
|
deep_norm_beta: float
|
||||||
|
|
||||||
|
# Number of heads in the attention
|
||||||
|
n_heads: int = 4
|
||||||
|
# Embedding size
|
||||||
|
d_model: int = 256
|
||||||
|
# Size of each attention head
|
||||||
|
d_k: int = 16
|
||||||
|
# Feed forward layer size
|
||||||
|
d_ff: int = 256
|
||||||
|
|
||||||
|
# FTA
|
||||||
|
fta_lower_limit: float = -1.
|
||||||
|
fta_upper_limit: float = +1.
|
||||||
|
fta_delta: float = 0.2
|
||||||
|
fta_eta: float = 0.05
|
||||||
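For reference (a quick consistency check, not part of the commit): with these defaults the FTA expansion factor and the resulting hidden width of the feed-forward block are

$$\frac{u - l}{\delta} = \frac{1 - (-1)}{0.2} = 10, \qquad d_{ff} \cdot \frac{u - l}{\delta} = 256 \cdot 10 = 2560,$$

which is the input size of the second linear layer in `FeedForwardFTA`.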
|
|
||||||
|
|
||||||
|
@option(Configs.model)
|
||||||
|
def _model(c: Configs):
|
||||||
|
"""
|
||||||
|
#### Initialize the model
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create FTA activation module
|
||||||
|
fta = FTA(c.fta_lower_limit, c.fta_upper_limit, c.fta_delta, c.fta_eta)
|
||||||
|
# Create the transformer.
|
||||||
|
# We re-use [`TransformerLayer`](../../transformers/models.html#TransformerLayer) and
|
||||||
|
# [`MultiHeadAttention`](../../transformers/mha.html) implementations.
|
||||||
|
m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.n_layers,
|
||||||
|
TransformerLayer(d_model=c.d_model,
|
||||||
|
feed_forward=FeedForwardFTA(d_model=c.d_model,
|
||||||
|
d_ff=c.d_ff,
|
||||||
|
activation=fta,
|
||||||
|
dropout=0.1),
|
||||||
|
self_attn=MultiHeadAttention(c.n_heads, c.d_model,
|
||||||
|
dropout_prob=0.0),
|
||||||
|
dropout_prob=0.0))
|
||||||
|
|
||||||
|
# Move to the device
|
||||||
|
return m.to(c.device)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""
|
||||||
|
#### Create and run the experiment
|
||||||
|
"""
|
||||||
|
# Create experiment
|
||||||
|
experiment.create(name="fta", writers={'screen', 'comet', 'labml'})
|
||||||
|
# Create configs
|
||||||
|
conf = Configs()
|
||||||
|
# Override configurations
|
||||||
|
experiment.configs(conf, {
|
||||||
|
# Use character level tokenizer
|
||||||
|
'tokenizer': 'character',
|
||||||
|
# Prompt separator is blank
|
||||||
|
'prompt_separator': '',
|
||||||
|
# Starting prompt for sampling
|
||||||
|
'prompt': 'It is ',
|
||||||
|
# Use Tiny Shakespeare dataset
|
||||||
|
'text': 'tiny_shakespeare',
|
||||||
|
|
||||||
|
# Use a context size of $256$
|
||||||
|
'seq_len': 256,
|
||||||
|
# Train for 32 epochs
|
||||||
|
'epochs': 32,
|
||||||
|
# Batch size $16$
|
||||||
|
'batch_size': 16,
|
||||||
|
# Switch between training and validation $10$ times per epoch
|
||||||
|
'inner_iterations': 10,
|
||||||
|
|
||||||
|
# Adam optimizer with no warmup
|
||||||
|
'optimizer.optimizer': 'Adam',
|
||||||
|
'optimizer.learning_rate': 3e-4,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Set model(s) for saving and loading
|
||||||
|
experiment.add_pytorch_models({'model': conf.model})
|
||||||
|
|
||||||
|
# Start the experiment
|
||||||
|
with experiment.start():
|
||||||
|
# Run training
|
||||||
|
conf.run()
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
@ -5,7 +5,7 @@ summary: >
|
|||||||
Training a DeepNorm transformer on Tiny Shakespeare.
|
Training a DeepNorm transformer on Tiny Shakespeare.
|
||||||
---
|
---
|
||||||
|
|
||||||
# DeepNorm Experiment
|
# [DeepNorm](index.html) Experiment
|
||||||
|
|
||||||
[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/deep_norm/experiment.ipynb)
|
[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/deep_norm/experiment.ipynb)
|
||||||
[](https://app.labml.ai/run/ec8e4dacb7f311ec8d1cd37d50b05c3d)
|
[](https://app.labml.ai/run/ec8e4dacb7f311ec8d1cd37d50b05c3d)
|
||||||
|
@ -141,7 +141,7 @@ class RotaryPositionalEmbeddings(nn.Module):
|
|||||||
idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
|
idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
|
||||||
|
|
||||||
# Concatenate so that for row $m$ we have
|
# Concatenate so that for row $m$ we have
|
||||||
# $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta 0, m \theta 1, ..., m \theta_{\frac{d}{2}}]$
|
# $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
|
||||||
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
|
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
|
||||||
|
|
||||||
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., -x^{(\frac{d}{2})}]$
|
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., -x^{(\frac{d}{2})}]$
|
||||||
|
@ -115,6 +115,11 @@ Solving games with incomplete information such as poker with CFR.
|
|||||||
|
|
||||||
* [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
|
* [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
|
||||||
|
|
||||||
|
#### ✨ [Activations](https://nn.labml.ai/activations/index.html)
|
||||||
|
|
||||||
|
* [Fuzzy Tiling Activations](https://nn.labml.ai/activations/fta/index.html)
|
||||||
|
|
||||||
|
|
||||||
## Highlighted Research Paper PDFs
|
## Highlighted Research Paper PDFs
|
||||||
|
|
||||||
* [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf)
|
* [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf)
|
||||||
|
2
setup.py
@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
|
|||||||
|
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name='labml-nn',
|
name='labml-nn',
|
||||||
version='0.4.121',
|
version='0.4.122',
|
||||||
author="Varuna Jayasiri, Nipun Wijerathne",
|
author="Varuna Jayasiri, Nipun Wijerathne",
|
||||||
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
|
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
|
||||||
description="🧑🏫 Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit), optimizers (adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, etc. 🧠",
|
description="🧑🏫 Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit), optimizers (adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, etc. 🧠",
|
||||||
|