Varuna Jayasiri c4d2e8cd22 docs
2025-07-31 08:48:07 +05:30


<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This implements the RWKV model using PyTorch with explanations."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Receptance Weighted Key Value (RWKV)"/>
<meta name="twitter:description" content="This implements the RWKV model using PyTorch with explanations."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/rwkv/index.html"/>
<meta property="og:title" content="Receptance Weighted Key Value (RWKV)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Receptance Weighted Key Value (RWKV)"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="This implements the RWKV model using PyTorch with explanations."/>
<title>Receptance Weighted Key Value (RWKV)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/rwkv/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">rwkv</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/rwkv/__init__.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Receptance Weighted Key Value (RWKV)</h1>
<p>This is a tutorial/implementation of RWKV from the paper <a href="https://arxiv.org/pdf/2305.13048.pdf">RWKV: Reinventing RNNs for the Transformer Era</a> in <a href="https://pytorch.org/">PyTorch</a>.</p>
<p>This is the full definition of an RWKV language model, all in this single file. References: (1) <a href="https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py">the official RWKV PyTorch implementation released by Bo Peng</a>; (2) <a href="https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py">the huggingface/transformers PyTorch implementation</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">22</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">23</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="lineno">25</span>
<span class="lineno">26</span>
<span class="lineno">27</span><span class="n">PREV_X_TIME</span> <span class="o">=</span> <span class="mi">0</span>
<span class="lineno">28</span><span class="n">NUM_STATE</span> <span class="o">=</span> <span class="mi">1</span>
<span class="lineno">29</span><span class="n">DEN_STATE</span> <span class="o">=</span> <span class="mi">2</span>
<span class="lineno">30</span><span class="n">MAX_STATE</span> <span class="o">=</span> <span class="mi">3</span>
<span class="lineno">31</span><span class="n">PREV_X_CHANNEL</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
</div>
</div>
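<div class='section'>
<div class='docs'>
<p>A minimal sketch of how the five constants above index a per-layer recurrent state. The layout <code>(n_layer, batch, 5, n_embd)</code> is inferred from how <code>ChannelMixing</code> reads <code>state[self.layer_id, :, [PREV_X_CHANNEL], :]</code>. <code>make_state</code> is a hypothetical helper, and zero initialization is used only to show the indexing; the time-mixing code, not shown here, may expect other starting values in the <code>NUM_STATE</code>, <code>DEN_STATE</code> and <code>MAX_STATE</code> slots.</p>
<p>The page indexes the slot with a one-element list, which returns a copy; the length-1 slice used here returns a view instead, hence the explicit <code>clone()</code> before writing.</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch

PREV_X_TIME = 0
NUM_STATE = 1
DEN_STATE = 2
MAX_STATE = 3
PREV_X_CHANNEL = 4


def make_state(n_layer, batch, n_embd):
    # Hypothetical helper: one slot per constant above, for every layer.
    # Zeros are used here only to demonstrate the indexing.
    return torch.zeros(n_layer, batch, 5, n_embd)


state = make_state(n_layer=2, batch=3, n_embd=8)

# A length-1 slice keeps the slot dimension, so the result is shaped
# (batch, 1, n_embd), like one time step of a (batch, time, channel) input.
# clone(): a slice is a view and would otherwise see the write below.
prev_x = state[1, :, PREV_X_CHANNEL:PREV_X_CHANNEL + 1, :].clone()
print(prev_x.shape)  # torch.Size([3, 1, 8])

# Store one time step in layer 1's PREV_X_CHANNEL slot
x = torch.randn(3, 1, 8)
state[1, :, PREV_X_CHANNEL:PREV_X_CHANNEL + 1, :] = x
print(torch.equal(state[1, :, PREV_X_CHANNEL, :], x[:, 0, :]))  # True</pre></div>
</div>
</div>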
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h3>Layer normalization with optional bias</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">34</span><span class="k">class</span> <span class="nc">LayerNorm</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">39</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ndim</span><span class="p">,</span> <span class="n">bias</span><span class="p">):</span>
<span class="lineno">40</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">ndim</span><span class="p">))</span>
<span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">ndim</span><span class="p">))</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">):</span>
<span class="lineno">45</span> <span class="k">return</span> <span class="n">F</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="mf">1e-5</span><span class="p">)</span></pre></div>
</div>
</div>
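<div class='section'>
<div class='docs'>
<p>A quick check, as a sketch: with <code>bias=True</code> and freshly initialized parameters, this module should produce the same output as PyTorch's built-in <code>nn.LayerNorm</code> with the same <code>eps</code> of <code>1e-5</code>; with <code>bias=False</code> only the scale vector is learned. The class is copied from above so the snippet runs on its own, and the tensor sizes are arbitrary.</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch
import torch.nn as nn
from torch.nn import functional as F


class LayerNorm(nn.Module):
    # Same module as above: layer norm whose bias can be switched off
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


x = torch.randn(2, 4, 16)

# With bias=True it matches nn.LayerNorm at initialization
ours = LayerNorm(16, bias=True)
ref = nn.LayerNorm(16, eps=1e-5)
print(torch.allclose(ours(x), ref(x), atol=1e-6))  # True

# With bias=False no bias parameter is registered, only the scale
no_bias = LayerNorm(16, bias=False)
print(sum(p.numel() for p in no_bias.parameters()))  # 16</pre></div>
</div>
</div>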
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<h3>L2 loss wrapper</h3>
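<p>Passes the loss through unchanged in the forward pass and adds a small extra gradient on the largest logit in the backward pass (<a href="https://github.com/BlinkDL/RWKV-LM/blob/cca1b5e8e597cf40675882bb10b46287c844e35c/RWKV-v4/src/model.py#L21">reference</a>).</p>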
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">48</span><span class="k">class</span> <span class="nc">L2Wrap</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="nd">@staticmethod</span>
<span class="lineno">56</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">loss</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="lineno">57</span> <span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
<span class="lineno">58</span> <span class="k">return</span> <span class="n">loss</span>
<span class="lineno">59</span>
<span class="lineno">60</span> <span class="nd">@staticmethod</span>
<span class="lineno">61</span> <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">grad_output</span><span class="p">):</span>
<span class="lineno">62</span> <span class="n">y</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
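<p>Encourage the logits to stay close to 0. For logits of shape <code>(B, T, vocab)</code>, the largest logit at every (batch, time) position receives an extra gradient of <code>factor × max</code>, with <code>factor = 1e-4 / (B · T)</code>; the gradient of the loss itself is passed through unchanged.</p>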
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="n">factor</span> <span class="o">=</span> <span class="mf">1e-4</span> <span class="o">/</span> <span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="lineno">65</span> <span class="n">maxx</span><span class="p">,</span> <span class="n">ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">66</span> <span class="n">gy</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
<span class="lineno">67</span> <span class="n">gy</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">ids</span><span class="p">,</span> <span class="n">maxx</span> <span class="o">*</span> <span class="n">factor</span><span class="p">)</span>
<span class="lineno">68</span> <span class="k">return</span> <span class="n">grad_output</span><span class="p">,</span> <span class="n">gy</span></pre></div>
</div>
</div>
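<div class='section'>
<div class='docs'>
<p>One way to read <code>L2Wrap</code>, offered as an interpretation rather than something the reference states: the extra gradient <code>factor × max</code> on the largest logit is exactly the gradient of a penalty <code>(factor / 2) · Σ max(y)²</code> summed over all (batch, time) positions. The reported loss is unchanged, but training behaves as if that penalty were added. This sketch checks the equivalence numerically with a toy cross-entropy loss; the shapes and seed are arbitrary.</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch
import torch.nn.functional as F


class L2Wrap(torch.autograd.Function):
    # Copy of the class above
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return grad_output, gy


torch.manual_seed(0)
B, T, V = 2, 3, 5
logits = torch.randn(B, T, V, requires_grad=True)
targets = torch.randint(0, V, (B, T))

# Gradient through L2Wrap
loss = F.cross_entropy(logits.view(-1, V), targets.view(-1))
L2Wrap.apply(loss, logits).backward()
grad_wrapped = logits.grad.clone()

# Gradient of the same loss plus an explicit penalty on the largest logit
logits.grad = None
factor = 1e-4 / (B * T)
loss = F.cross_entropy(logits.view(-1, V), targets.view(-1))
penalty = 0.5 * factor * logits.max(dim=-1).values.pow(2).sum()
(loss + penalty).backward()

print(torch.allclose(grad_wrapped, logits.grad))  # True</pre></div>
</div>
</div>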
<div class='section' id='section-7'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<h3>Channel Mixing</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span><span class="k">class</span> <span class="nc">ChannelMixing</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">,</span> <span class="n">layer_id</span><span class="p">):</span>
<span class="lineno">77</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">78</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ZeroPad2d</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
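<p>The <code>time_shift</code> above implements token shifting: on an input of shape <code>(Batch, Time, Channel)</code> it inserts a zero vector at the start of the time axis and drops the last step, so position <em>t</em> receives <em>x</em><sub><em>t</em>−1</sub>.</p>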
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_id</span> <span class="o">=</span> <span class="n">layer_id</span>
<span class="lineno">81</span>
<span class="lineno">82</span> <span class="n">n_embd</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">n_embd</span>
<span class="lineno">83</span> <span class="n">intermediate_size</span> <span class="o">=</span> <span class="p">(</span>
<span class="lineno">84</span> <span class="n">config</span><span class="o">.</span><span class="n">intermediate_size</span> <span class="k">if</span> <span class="n">config</span><span class="o">.</span><span class="n">intermediate_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">n_embd</span>
<span class="lineno">85</span> <span class="p">)</span></pre></div>
</div>
</div>
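<div class='section'>
<div class='docs'>
<p>A small sketch of the token shift in isolation. <code>nn.ZeroPad2d</code> pads the last two dimensions, here time and channel; the padding <code>(0, 0, 1, -1)</code> adds one zero row at the top of the time axis and, through the negative value, crops one row from the bottom. The concatenation at the end is an alternative formulation shown for comparison, not what this implementation uses.</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch
import torch.nn as nn

# Same padding as ChannelMixing.time_shift: (left, right, top, bottom)
# applied to the last two dims, i.e. (time, channel) for a (batch, time, channel) input
time_shift = nn.ZeroPad2d((0, 0, 1, -1))

x = torch.arange(1.0, 7.0).view(1, 3, 2)  # batch=1, time=3, channel=2
prev_x = time_shift(x)
print(prev_x[0].tolist())  # [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]]

# Equivalent formulation: prepend a zero step, drop the last step
alt = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
print(torch.equal(prev_x, alt))  # True</pre></div>
</div>
</div>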
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
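<p>Learnable matrices <em>W<sub>k</sub></em>, <em>W<sub>v</sub></em> and <em>W<sub>r</sub></em>. The key projection expands to <code>intermediate_size</code> (4 × <code>n_embd</code> by default) and the value projection maps back to <code>n_embd</code>.</p>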
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">intermediate_size</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="lineno">89</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">intermediate_size</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="lineno">90</span> <span class="bp">self</span><span class="o">.</span><span class="n">receptance_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
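<div class='section'>
<div class='docs'>
<p>A shape check for the three projections, using an arbitrary toy size <code>n_embd = 8</code> and the default <code>intermediate_size = 4 * n_embd</code>. It only tracks shapes; an elementwise activation applied between <code>key_proj</code> and <code>value_proj</code> would not change them.</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch
import torch.nn as nn

n_embd = 8
intermediate_size = 4 * n_embd  # default when config.intermediate_size is None

key_proj = nn.Linear(n_embd, intermediate_size, bias=False)
value_proj = nn.Linear(intermediate_size, n_embd, bias=False)
receptance_proj = nn.Linear(n_embd, n_embd, bias=False)

x = torch.randn(2, 5, n_embd)  # (batch, time, channel)
k = key_proj(x)
print(k.shape)                   # torch.Size([2, 5, 32])
print(value_proj(k).shape)       # torch.Size([2, 5, 8])
print(receptance_proj(x).shape)  # torch.Size([2, 5, 8])

# nn.Linear stores its weight as (out_features, in_features)
print(key_proj.weight.shape)     # torch.Size([32, 8])</pre></div>
</div>
</div>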
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Learnable time-mixing vectors <em>μ<sub>k</sub></em> and <em>μ<sub>r</sub></em>, with one coefficient per channel</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_key</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">))</span>
<span class="lineno">94</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_receptance</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">))</span></pre></div>
</div>
</div>
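<div class='section'>
<div class='docs'>
<p><code>torch.empty</code> allocates uninitialized memory, so these vectors must be filled before the first forward pass; the initialization is not shown in this part of the page. The sketch uses a hypothetical scheme, a per-channel ramp <code>(i / n_embd) ** ratio</code> with a layer-dependent exponent, chosen only because it keeps every coefficient between 0 and 1. It is not necessarily what this implementation does.</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch
import torch.nn as nn


def init_time_mix(param, layer_id, n_layer):
    # Hypothetical initialization: channel i gets (i / n_embd) ** ratio,
    # where ratio shrinks from 1 toward 0 for deeper layers.
    n_embd = param.shape[-1]
    ratio = 1.0 - layer_id / n_layer
    ramp = torch.arange(n_embd, dtype=param.dtype) / n_embd
    with torch.no_grad():
        param.copy_(ramp.pow(ratio).view(1, 1, n_embd))


n_embd, n_layer = 8, 4
time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))  # uninitialized, as above
init_time_mix(time_mix_key, layer_id=1, n_layer=n_layer)

# Clamping to [0, 1] changes nothing, so every coefficient is a valid
# interpolation weight between the current and the previous token.
print(torch.equal(time_mix_key.clamp(0, 1), time_mix_key))  # True</pre></div>
</div>
</div>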
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
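<p>The input <code>x</code> has shape <code>(Batch, Time, Channel)</code>. When <code>state</code> is passed, <code>x</code> is expected to hold a single time step (Time = 1), since it is written directly into the state's <code>PREV_X_CHANNEL</code> slot.</p>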
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">96</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">state</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span> <span class="k">if</span> <span class="n">state</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">101</span> <span class="n">prev_x</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_id</span><span class="p">,</span> <span class="p">:,</span> <span class="p">[</span><span class="n">PREV_X_CHANNEL</span><span class="p">],</span> <span class="p">:]</span>
<span class="lineno">102</span> <span class="n">state</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_id</span><span class="p">,</span> <span class="p">:,</span> <span class="p">[</span><span class="n">PREV_X_CHANNEL</span><span class="p">],</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">x</span>
<span class="lineno">103</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">104</span> <span class="n">prev_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_shift</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
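<div class='section'>
<div class='docs'>
<p>The two branches should agree: feeding a sequence one token at a time through the state should give the same <code>prev_x</code> as applying <code>time_shift</code> to the whole sequence, provided the <code>PREV_X_CHANNEL</code> slot starts at zero. The sketch reproduces only this part of <code>forward</code>; <code>prev_x_step</code> is a hypothetical function, and it selects the layer before applying the list index to keep the resulting shape explicit.</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch
import torch.nn as nn

PREV_X_CHANNEL = 4
time_shift = nn.ZeroPad2d((0, 0, 1, -1))


def prev_x_step(x, state, layer_id):
    # Mirrors the `state is not None` branch for one token x of shape (batch, 1, channel)
    layer_state = state[layer_id]                  # a view into state
    prev_x = layer_state[:, [PREV_X_CHANNEL], :]   # list indexing returns a copy
    layer_state[:, [PREV_X_CHANNEL], :] = x        # writes through the view into state
    return prev_x


torch.manual_seed(0)
n_layer, batch, T, C = 2, 3, 4, 5
x = torch.randn(batch, T, C)

# Parallel mode: shift the whole sequence at once
parallel = time_shift(x)

# Recurrent mode: one token at a time through the state
state = torch.zeros(n_layer, batch, 5, C)
steps = [prev_x_step(x[:, t:t + 1, :], state, layer_id=1) for t in range(T)]
recurrent = torch.cat(steps, dim=1)

print(torch.equal(parallel, recurrent))  # True</pre></div>
</div>
</div>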
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">r</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">r</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="">μ</span><span class="msupsub"><span 
class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">r</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mopen" style="">(</span><span class="mord" style="">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">μ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" 
style="margin-right:0.02778em">r</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">t</span><span class="mbin mtight" style="">−</span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span> <span class="n">receptance</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_receptance</span> <span class="o">+</span> <span class="n">prev_x</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_receptance</span><span class="p">)</span>
<span class="lineno">108</span> <span class="n">receptance</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">receptance_proj</span><span class="p">(</span><span class="n">receptance</span><span class="p">)</span></pre></div>
</div>
</div>
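<div class='section'>
<div class='docs'>
<p>A worked example of the mixing step with toy numbers: a coefficient of 1 keeps only the current token, 0 keeps only the previous one, and values in between blend the two. The values below are made up for illustration, and the projection <em>W<sub>r</sub></em> is left out.</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch

# Toy values: one batch, two time steps, three channels
x = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])       # x_t for t = 0, 1
prev_x = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]])  # x_(t-1), zeros at t = 0
mu_r = torch.tensor([[[1.0, 0.5, 0.0]]])                     # per-channel mixing weights

mixed = x * mu_r + prev_x * (1 - mu_r)

# Channel 0 takes x_t, channel 1 averages, channel 2 takes x_(t-1)
print(mixed[0, 1].tolist())  # [4.0, 3.5, 3.0]</pre></div>
</div>
</div>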
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqb" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03148em">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03148em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="">μ</span><span 
class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mopen" style="">(</span><span class="mord" style="">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">μ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span 
class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">t</span><span class="mbin mtight" style="">−</span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">111</span> <span class="n">key</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_key</span> <span class="o">+</span> <span class="n">prev_x</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_key</span><span class="p">)</span>
<span class="lineno">112</span> <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_proj</span><span class="p">(</span><span class="n">key</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">v</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">&sdot;</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.064108em;vertical-align:-0.25em;"></span><span class="mord 
mathnormal">ma</span><span class="mord mathnormal">x</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03148em;">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03148em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">0</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8141079999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">115</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_proj</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">key</span><span class="p">)))</span></pre></div>
</div>
</div>
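<div class='section'>
<div class='docs'>
<p>The value path applies a squared ReLU, max(<em>k<sub>t</sub></em>, 0)<sup>2</sup>, element-wise. The sketch below wraps <code>torch.square(torch.relu(key))</code> in a helper (the name <code>squared_relu</code> is ours, not part of this implementation) and checks it on a few values: non-positive entries become zero and positive ones are squared.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn.functional as F


def squared_relu(k: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: max(k, 0) ** 2, applied element-wise,
    # mirroring torch.square(torch.relu(key)) in the channel-mixing code
    return torch.square(torch.relu(k))


k = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0])
out = squared_relu(k)
# Non-positive entries map to 0; 1.5 and 3.0 map to 2.25 and 9.0
same = torch.equal(out, F.relu(k) ** 2)  # equivalent spelling with F.relu and **
```
</pre></div>
</div>
</div>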
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">o</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">&odot;</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" 
style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">118</span> <span class="n">out</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">receptance</span><span class="p">)</span> <span class="o">*</span> <span class="n">value</span>
<span class="lineno">119</span> <span class="k">return</span> <span class="n">out</span><span class="p">,</span> <span class="n">state</span></pre></div>
</div>
</div>
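<div class='section'>
<div class='docs'>
<p>The channel-mixing steps can be assembled into one self-contained module. This is a sketch under stated assumptions, not the class defined above: the name <code>ChannelMixSketch</code>, the <code>hidden</code> size argument and the constant 0.5 initialisation of the mixing vectors are hypothetical, and the recurrent <code>state</code> path is left out, so only the parallel (training) mode is shown. Because the only look-back is the one-step token shift, the output at step <em>t</em> depends on <em>x<sub>t</sub></em> and <em>x<sub>t&minus;1</sub></em> alone.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn as nn


class ChannelMixSketch(nn.Module):
    """Hypothetical channel-mixing block for (Batch, Time, Channel) inputs."""

    def __init__(self, n_embd: int, hidden: int):
        super().__init__()
        # Token shift: prepend a zero step on the time axis, drop the last one
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        # Assumed initialisation: an even blend of x_t and x_{t-1}
        self.time_mix_key = nn.Parameter(torch.full((1, 1, n_embd), 0.5))
        self.time_mix_receptance = nn.Parameter(torch.full((1, 1, n_embd), 0.5))
        self.key_proj = nn.Linear(n_embd, hidden, bias=False)
        self.value_proj = nn.Linear(hidden, n_embd, bias=False)
        self.receptance_proj = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        prev_x = self.time_shift(x)
        mu_k, mu_r = self.time_mix_key, self.time_mix_receptance
        # k_t and r_t from the token-shifted inputs
        key = self.key_proj(x * mu_k + prev_x * (1 - mu_k))
        receptance = self.receptance_proj(x * mu_r + prev_x * (1 - mu_r))
        # v_t = W_v . max(k_t, 0)^2
        value = self.value_proj(torch.square(torch.relu(key)))
        # o_t = sigmoid(r_t) * v_t, an element-wise gate
        return torch.sigmoid(receptance) * value


torch.manual_seed(0)
block = ChannelMixSketch(n_embd=8, hidden=32)
x = torch.randn(2, 5, 8)
y = block(x)  # same shape as x: (2, 5, 8)
```
</pre></div>
</div>
</div>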
<div class='section' id='section-18'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<h3>Time Mixing</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span><span class="k">class</span> <span class="nc">TimeMixing</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">,</span> <span class="n">layer_id</span><span class="p">):</span>
<span class="lineno">128</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">129</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span> <span class="o">=</span> <span class="n">config</span>
<span class="lineno">130</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_shift</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ZeroPad2d</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>
<span class="lineno">131</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_id</span> <span class="o">=</span> <span class="n">layer_id</span>
<span class="lineno">132</span>
<span class="lineno">133</span> <span class="n">n_embd</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">n_embd</span>
<span class="lineno">134</span> <span class="n">attn_sz</span> <span class="o">=</span> <span class="n">n_embd</span></pre></div>
</div>
</div>
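<div class='section'>
<div class='docs'>
<p><code>nn.ZeroPad2d((0, 0, 1, -1))</code> pads the last two dimensions in the order (left, right, top, bottom). For a <code>(Batch, Time, Channel)</code> tensor those are time and channel, so channels are untouched, one zero step is prepended on the time axis, and the negative padding crops the final step. The result is <code>x</code> delayed by one token, with zeros standing in for the token before the first. A small check on a toy tensor:</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn as nn

# (left, right, top, bottom) applies to the last two dims: (Time, Channel)
time_shift = nn.ZeroPad2d((0, 0, 1, -1))

x = torch.arange(1.0, 7.0).reshape(1, 3, 2)  # batch 1, three steps, two channels
prev_x = time_shift(x)
# prev_x[:, t] is x[:, t - 1]; the first step sees zeros
```
</pre></div>
</div>
</div>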
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
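<p>Learnable projection matrices <em>W<sub>k</sub></em>, <em>W<sub>v</sub></em> and <em>W<sub>r</sub></em>, plus the output projection that maps back to <code>n_embd</code>.</p>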
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">attn_sz</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="lineno">138</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">attn_sz</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="lineno">139</span> <span class="bp">self</span><span class="o">.</span><span class="n">receptance_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">attn_sz</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="lineno">140</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">attn_sz</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
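<p>Learnable per-channel vectors: <code>time_decay</code> and <code>time_first</code> (the decay and the current-token bonus used by the WKV operator), and the token-shift mixing coefficients <code>time_mix_key</code>, <code>time_mix_value</code> and <code>time_mix_receptance</code> (the <em>μ</em>'s below), shaped <code>(1, 1, n_embd)</code> so they broadcast over batch and time.</p>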
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_decay</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">attn_sz</span><span class="p">))</span>
<span class="lineno">144</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_first</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">attn_sz</span><span class="p">))</span>
<span class="lineno">145</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_key</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">))</span>
<span class="lineno">146</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_value</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">))</span>
<span class="lineno">147</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_receptance</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">))</span></pre></div>
</div>
</div>
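<div class='section'>
<div class='docs'>
<p><code>time_decay</code> and <code>time_first</code> feed the RWKV-4 WKV operator, which averages past values with weights that decay exponentially with distance and gives the current token an extra bonus. The naive sketch below is for intuition only. It assumes, as in the reference RWKV-4 code, that the per-channel decay rate is <code>w = exp(time_decay)</code> (so it stays positive) and that <code>time_first</code> is used directly as the bonus <code>u</code>; both function names are hypothetical. It computes the same quantity twice: as an explicit weighted sum over past tokens, and as a recurrence that carries only two running sums per channel, which is what allows RNN-style inference. Numerically stable implementations also subtract a running maximum exponent; this sketch omits that and is only safe for small inputs.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch


def wkv_explicit(time_decay, time_first, k, v):
    """Hypothetical naive WKV. k, v: (T, C); time_decay, time_first: (C,)."""
    w = torch.exp(time_decay)  # assumed: positive per-channel decay rate
    u = time_first             # assumed: bonus for the current token
    outs = []
    for t in range(k.shape[0]):
        # Current token, weighted by e^(u + k_t)
        num = torch.exp(u + k[t]) * v[t]
        den = torch.exp(u + k[t])
        # Past tokens i, weighted by e^(-(t - 1 - i) w + k_i)
        for i in range(t):
            weight = torch.exp(-(t - 1 - i) * w + k[i])
            num = num + weight * v[i]
            den = den + weight
        outs.append(num / den)
    return torch.stack(outs)


def wkv_recurrent(time_decay, time_first, k, v):
    """Same quantity with constant state per channel: running sums a and b."""
    w = torch.exp(time_decay)
    u = time_first
    a = torch.zeros_like(v[0])  # sum of decayed e^(k_i) v_i over past tokens
    b = torch.zeros_like(v[0])  # sum of decayed e^(k_i) over past tokens
    outs = []
    for t in range(k.shape[0]):
        bonus = torch.exp(u + k[t])
        outs.append((a + bonus * v[t]) / (b + bonus))
        # Decay the history by e^(-w), then absorb the current token
        a = torch.exp(-w) * a + torch.exp(k[t]) * v[t]
        b = torch.exp(-w) * b + torch.exp(k[t])
    return torch.stack(outs)


torch.manual_seed(0)
T, C = 6, 4
k = torch.randn(T, C, dtype=torch.float64)
v = torch.randn(T, C, dtype=torch.float64)
time_decay = torch.randn(C, dtype=torch.float64)
time_first = torch.randn(C, dtype=torch.float64)
```
</pre></div>
</div>
</div>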
<div class='section' id='section-22'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
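<p><code>x</code> has shape <code>(Batch, Time, Channel)</code>; <code>state</code> is the optional recurrent state used for step-by-step inference.</p>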
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">149</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">state</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span> <span class="k">if</span> <span class="n">state</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">154</span> <span class="n">prev_x</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_id</span><span class="p">,</span> <span class="p">:,</span> <span class="p">[</span><span class="n">PREV_X_TIME</span><span class="p">],</span> <span class="p">:]</span>
<span class="lineno">155</span> <span class="n">state</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_id</span><span class="p">,</span> <span class="p">:,</span> <span class="p">[</span><span class="n">PREV_X_TIME</span><span class="p">],</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">x</span>
<span class="lineno">156</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">157</span> <span class="n">prev_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_shift</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
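<div class='section'>
<div class='docs'>
<p>The two branches above produce the same <code>prev_x</code>. In training the whole sequence is shifted at once by <code>time_shift</code>. In step-by-step inference the previous token is read from <code>state</code> and then overwritten with the current one; because <code>[PREV_X_TIME]</code> is a list index, PyTorch's advanced indexing returns a copy, so the read is not affected by the write that follows. The sketch below checks the equivalence with a plain cached tensor standing in for the state, assuming that cache starts at zero like the padding does; the function names are hypothetical.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn as nn

time_shift = nn.ZeroPad2d((0, 0, 1, -1))


def prev_x_parallel(x: torch.Tensor) -> torch.Tensor:
    # Training mode: shift the whole (Batch, Time, Channel) sequence at once
    return time_shift(x)


def prev_x_recurrent(x: torch.Tensor) -> torch.Tensor:
    # Inference mode: feed one step at a time, caching the last input
    batch, steps, channels = x.shape
    cache = x.new_zeros(batch, 1, channels)  # stands in for the PREV_X_TIME slot
    outs = []
    for t in range(steps):
        x_t = x[:, t:t + 1, :]  # a single step, shape (Batch, 1, Channel)
        outs.append(cache)      # read the previous token first ...
        cache = x_t             # ... then replace it with the current one
    return torch.cat(outs, dim=1)


torch.manual_seed(0)
x = torch.randn(2, 5, 3)
```
</pre></div>
</div>
</div>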
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">r</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">r</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">&sdot;</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="">μ</span><span class="msupsub"><span 
class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">r</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mopen" style="">(</span><span class="mord" style="">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">&minus;</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">μ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" 
style="margin-right:0.02778em">r</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">t</span><span class="mbin mtight" style="">&minus;</span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="n">receptance</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_receptance</span> <span class="o">+</span> <span class="n">prev_x</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_receptance</span><span class="p">)</span>
<span class="lineno">161</span> <span class="n">receptance</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">receptance_proj</span><span class="p">(</span><span class="n">receptance</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqb" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03148em">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03148em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">&sdot;</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="">μ</span><span 
class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mopen" style="">(</span><span class="mord" style="">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">&minus;</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord mathnormal" style="">μ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span 
class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span><span class="mord" style=""><span class="mord mathnormal" style="">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">t</span><span class="mbin mtight" style="">&minus;</span><span class="mord mtight" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mclose" style="">)</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="n">key</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_key</span> <span class="o">+</span> <span class="n">prev_x</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_key</span><span class="p">)</span>
<span class="lineno">165</span> <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_proj</span><span class="p">(</span><span class="n">key</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">v</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">&sdot;</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span 
class="mord"><span class="mord mathnormal">μ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">v</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">&minus;</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal">μ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span 
style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">v</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.301108em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">&minus;</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.208331em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span>        <span class="n">value</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_value</span> <span class="o">+</span> <span class="n">prev_x</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_mix_value</span><span class="p">)</span>
<span class="lineno">169</span>        <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_proj</span><span class="p">(</span><span class="n">value</span><span class="p">)</span></pre></div>
</div>
</div>
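<div class='section'>
<div class='docs'>
<p>A minimal pure-Python sketch of the token-shift interpolation inside the parentheses, applied to a short sequence of vectors. The names <code>time_mix</code> and <code>token_shift</code> are hypothetical, and using zeros as the input before the first token is an assumption of this sketch; the projection <em>W<sub>v</sub></em> (<code>value_proj</code>) would then be applied to each mixed vector.</p>
<p>With μ = 1 a channel sees only the current token and with μ = 0 only the previous one; values in between blend the two per channel.</p>
</div>
<div class='code'>
<div class="highlight"><pre>def time_mix(x_t, x_prev, mu):
    """Per-channel interpolation: mu * x_t + (1 - mu) * x_prev."""
    return [m * a + (1.0 - m) * b for a, b, m in zip(x_t, x_prev, mu)]


def token_shift(xs, mu, x_init=None):
    """Mix each time step with the step before it.

    xs holds one vector (list of floats) per time step. x_init stands in
    for the input before the first step; zeros are assumed when it is None.
    """
    prev = x_init if x_init is not None else [0.0] * len(xs[0])
    mixed = []
    for x in xs:
        mixed.append(time_mix(x, prev, mu))
        prev = x
    return mixed


mixed = token_shift([[1.0, 2.0], [3.0, 4.0]], mu=[1.0, 0.25])
# mixed == [[1.0, 0.5], [3.0, 2.5]]</pre></div>
</div>
</div>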
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
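<p>WKV calculation. When no <code>state</code> is passed in, the running numerator and denominator start at zero and <code>max_state</code> starts at <code>-1e38</code> (all float32), so the empty history contributes nothing; otherwise all three are read from this layer's slice of <code>state</code>. <code>time_decay</code> holds −<em>w</em>, where <em>w</em> = exp(<code>self.time_decay</code>) is the positive decay rate in the formula below.</p>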
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">172</span>        <span class="n">_</span><span class="p">,</span> <span class="n">seq_length</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
<span class="lineno">173</span>        <span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
<span class="lineno">174</span>
<span class="lineno">175</span>        <span class="k">if</span> <span class="n">state</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">176</span>            <span class="n">num_state</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">key</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="lineno">177</span>            <span class="n">den_state</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">key</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="lineno">178</span>            <span class="n">max_state</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">key</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">-</span> <span class="mf">1e38</span>
<span class="lineno">179</span>        <span class="k">else</span><span class="p">:</span>
<span class="lineno">180</span>            <span class="n">num_state</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_id</span><span class="p">,</span> <span class="p">:,</span> <span class="n">NUM_STATE</span><span class="p">,</span> <span class="p">:]</span>
<span class="lineno">181</span>            <span class="n">den_state</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_id</span><span class="p">,</span> <span class="p">:,</span> <span class="n">DEN_STATE</span><span class="p">,</span> <span class="p">:]</span>
<span class="lineno">182</span>            <span class="n">max_state</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_id</span><span class="p">,</span> <span class="p">:,</span> <span class="n">MAX_STATE</span><span class="p">,</span> <span class="p">:]</span>
<span class="lineno">183</span>
<span class="lineno">184</span>        <span class="n">time_decay</span> <span class="o">=</span> <span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">time_decay</span><span class="p">)</span>
<span class="lineno">185</span>
<span class="lineno">186</span>        <span class="k">for</span> <span class="n">current_index</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">seq_length</span><span class="p">):</span>
<span class="lineno">187</span>            <span class="n">current_key</span> <span class="o">=</span> <span class="n">key</span><span class="p">[:,</span> <span class="n">current_index</span><span class="p">]</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
<span class="lineno">188</span>            <span class="n">current_value</span> <span class="o">=</span> <span class="n">value</span><span class="p">[:,</span> <span class="n">current_index</span><span class="p">]</span></pre></div>
</div>
</div>
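<div class='section'>
<div class='docs'>
<p>Why the initial <code>max_state</code> is <code>-1e38</code>: the loop stores each running sum in scaled form, the true numerator being <code>num_state * exp(max_state)</code> (likewise for the denominator), so a very negative maximum stands in for log 0. A single-channel sketch with Python floats (the constants and <code>first_output</code> are hypothetical names) shows that the first output then reduces to the first value, as the <em>wkv<sub>t</sub></em> formula predicts for an empty history.</p>
</div>
<div class='code'>
<div class="highlight"><pre>import math

# Single-channel initial state, mirroring the `state is None` branch.
NUM0, DEN0, MAX0 = 0.0, 0.0, -1e38


def first_output(k0, v0, time_first):
    """wkv for the first token, using the loop's max-subtraction."""
    max_for_output = max(MAX0, k0 + time_first)
    e1 = math.exp(MAX0 - max_for_output)              # underflows to 0.0
    e2 = math.exp(k0 + time_first - max_for_output)   # exp(0.0) == 1.0
    return (e1 * NUM0 + e2 * v0) / (e1 * DEN0 + e2)


print(first_output(2.0, 7.5, 0.3))  # 7.5</pre></div>
</div>
</div>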
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord mathnormal" style="margin-right:0.02691em;">w</span><span class="mord mathnormal" style="margin-right:0.03148em;">k</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.8550919999999997em;vertical-align:-0.6433849999999999em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.2117069999999999em;"><span style="top:-2.582215em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mop mtight"><span class="mop op-symbol small-op mtight" style="position:relative;top:-0.0000050000000000050004em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8646357142857142em;"><span style="top:-2.177714285714286em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing 
reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mrel mtight">=</span><span class="mord mtight">1</span></span></span></span><span style="top:-2.9043214285714285em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.3222857142857143em;"><span></span></span></span></span></span></span><span class="mspace mtight" style="margin-right:0.19516666666666668em;"></span><span class="mord mtight"><span class="mord mathnormal mtight">e</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8682642857142857em;"><span style="top:-2.868264285714286em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5357142857142856em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight">−</span><span class="mopen mtight">(</span><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span><span class="mbin mtight">−</span><span class="mord mathnormal mtight">i</span><span class="mclose mtight">)</span><span class="mord mathnormal mtight" style="margin-right:0.02691em;">w</span><span class="mbin mtight">+</span><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3448em;margin-left:-0.03148em;margin-right:0.1em;"><span class="pstrut" style="height:2.65952em;"></span><span class="mord mathnormal mtight">i</span></span></span><span 
class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.31472em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="mbin mtight">+</span><span class="mord mtight"><span class="mord mathnormal mtight">e</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7968357142857142em;"><span style="top:-2.800807142857143em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">u</span><span class="mbin mtight">+</span><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3447999999999998em;margin-left:-0.03148em;margin-right:0.1em;"><span class="pstrut" style="height:2.61508em;"></span><span class="mord mathnormal mtight">t</span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.27027999999999996em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.5350070000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mop mtight"><span class="mop op-symbol small-op mtight" style="position:relative;top:-0.0000050000000000050004em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8913142857142857em;"><span 
style="top:-2.1785614285714283em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mrel mtight">=</span><span class="mord mtight">1</span></span></span></span><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.32143857142857146em;"><span></span></span></span></span></span></span><span class="mspace mtight" style="margin-right:0.19516666666666668em;"></span><span class="mord mtight"><span class="mord mathnormal mtight">e</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8682642857142857em;"><span style="top:-2.868264285714286em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5357142857142856em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight">−</span><span class="mopen mtight">(</span><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span><span class="mbin mtight">−</span><span class="mord mathnormal mtight">i</span><span class="mclose mtight">)</span><span class="mord mathnormal mtight" style="margin-right:0.02691em;">w</span><span class="mbin mtight">+</span><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span 
style="top:-2.3448em;margin-left:-0.03148em;margin-right:0.1em;"><span class="pstrut" style="height:2.65952em;"></span><span class="mord mathnormal mtight">i</span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.31472em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.03588em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mbin mtight">+</span><span class="mord mtight"><span class="mord mathnormal mtight">e</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9270285714285713em;"><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">u</span><span class="mbin mtight">+</span><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3447999999999998em;margin-left:-0.03148em;margin-right:0.1em;"><span class="pstrut" style="height:2.61508em;"></span><span class="mord mathnormal mtight">t</span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span 
class="vlist" style="height:0.27027999999999996em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.29634285714285713em;"><span style="top:-2.357em;margin-left:-0.03588em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.6433849999999999em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span>            <span class="n">max_for_output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">max_state</span><span class="p">,</span> <span class="n">current_key</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_first</span><span class="p">)</span>
<span class="lineno">192</span>            <span class="n">e1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">max_state</span> <span class="o">-</span> <span class="n">max_for_output</span><span class="p">)</span>
<span class="lineno">193</span>            <span class="n">e2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">current_key</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_first</span> <span class="o">-</span> <span class="n">max_for_output</span><span class="p">)</span>
<span class="lineno">194</span>            <span class="n">numerator</span> <span class="o">=</span> <span class="n">e1</span> <span class="o">*</span> <span class="n">num_state</span> <span class="o">+</span> <span class="n">e2</span> <span class="o">*</span> <span class="n">current_value</span>
<span class="lineno">195</span>            <span class="n">denominator</span> <span class="o">=</span> <span class="n">e1</span> <span class="o">*</span> <span class="n">den_state</span> <span class="o">+</span> <span class="n">e2</span>
<span class="lineno">196</span>            <span class="n">output</span><span class="p">[:,</span> <span class="n">current_index</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">numerator</span> <span class="o">/</span> <span class="n">denominator</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></pre></div>
</div>
</div>
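<div class='section'>
<div class='docs'>
<p>The three lines computing <code>max_for_output</code>, <code>e1</code> and <code>e2</code> are the log-sum-exp trick: dividing numerator and denominator by the same factor e<sup>max_for_output</sup> leaves the ratio unchanged while keeping every exponent at or below zero. A single-channel sketch with Python floats (hypothetical function names; <code>a</code> and <code>b</code> denote the unscaled history sums) compares it with the direct formula.</p>
<p>For example, <code>wkv_output_naive(0.0, 0.0, 1000.0, 3.0, 0.1)</code> raises <code>OverflowError</code>, while <code>wkv_output_stable(2.0, 0.5, 999.0, 1000.0, 3.0, 0.1)</code> returns a finite value. In float32, as used by the loop above, the overflow threshold for <code>exp</code> is lower still (about 88.7).</p>
</div>
<div class='code'>
<div class="highlight"><pre>import math


def wkv_output_naive(a, b, k, v, time_first):
    """Direct formula (a + e^(u+k) v) / (b + e^(u+k)) on unscaled sums."""
    e = math.exp(time_first + k)  # OverflowError once u + k exceeds about 709.78
    return (a + e * v) / (b + e)


def wkv_output_stable(num_state, den_state, max_state, k, v, time_first):
    """Same value with the history stored as (num, den) * exp(max_state)."""
    max_for_output = max(max_state, k + time_first)
    e1 = math.exp(max_state - max_for_output)        # exponent never positive
    e2 = math.exp(k + time_first - max_for_output)   # exponent never positive
    return (e1 * num_state + e2 * v) / (e1 * den_state + e2)</pre></div>
</div>
</div>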
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
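<p>Update the state for the next iteration. In unscaled terms this multiplies the history by e<sup>−w</sup> and adds the current token with weight e<sup>k<sub>t</sub></sup> (without the bonus <em>u</em>); as in the output step, both terms are rescaled by the new running maximum <code>max_for_state</code> so that no exponent passed to <code>exp</code> is positive.</p>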
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span>            <span class="n">max_for_state</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">max_state</span> <span class="o">+</span> <span class="n">time_decay</span><span class="p">,</span> <span class="n">current_key</span><span class="p">)</span>
<span class="lineno">200</span>            <span class="n">e1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">max_state</span> <span class="o">+</span> <span class="n">time_decay</span> <span class="o">-</span> <span class="n">max_for_state</span><span class="p">)</span>
<span class="lineno">201</span>            <span class="n">e2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">current_key</span> <span class="o">-</span> <span class="n">max_for_state</span><span class="p">)</span>
<span class="lineno">202</span>            <span class="n">num_state</span> <span class="o">=</span> <span class="n">e1</span> <span class="o">*</span> <span class="n">num_state</span> <span class="o">+</span> <span class="n">e2</span> <span class="o">*</span> <span class="n">current_value</span>
<span class="lineno">203</span>            <span class="n">den_state</span> <span class="o">=</span> <span class="n">e1</span> <span class="o">*</span> <span class="n">den_state</span> <span class="o">+</span> <span class="n">e2</span>
<span class="lineno">204</span>            <span class="n">max_state</span> <span class="o">=</span> <span class="n">max_for_state</span></pre></div>
</div>
</div>
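<div class='section'>
<div class='docs'>
<p>A single-channel sketch with Python floats (hypothetical function names; <em>a</em> and <em>b</em> are notation introduced here for the unscaled numerator and denominator sums) checking that decoding the scaled state as <code>num_state * exp(max_state)</code> reproduces the unscaled recurrence <em>a<sub>t</sub></em> = e<sup>−w</sup><em>a</em><sub>t−1</sub> + e<sup>k<sub>t</sub></sup><em>v<sub>t</sub></em>, <em>b<sub>t</sub></em> = e<sup>−w</sup><em>b</em><sub>t−1</sub> + e<sup>k<sub>t</sub></sup>, with <code>time_decay</code> = −<em>w</em>.</p>
</div>
<div class='code'>
<div class="highlight"><pre>import math


def update_state(num_state, den_state, max_state, k, v, time_decay):
    """One scaled state update; time_decay is negative (it holds -w)."""
    max_for_state = max(max_state + time_decay, k)
    e1 = math.exp(max_state + time_decay - max_for_state)
    e2 = math.exp(k - max_for_state)
    return e1 * num_state + e2 * v, e1 * den_state + e2, max_for_state


def update_unscaled(a, b, k, v, time_decay):
    """Reference recurrence on the true sums a and b."""
    decay = math.exp(time_decay)
    return decay * a + math.exp(k) * v, decay * b + math.exp(k)


num, den, mx = update_state(1.5, 0.7, 0.4, k=0.9, v=-2.0, time_decay=-0.3)
a, b = update_unscaled(1.5 * math.exp(0.4), 0.7 * math.exp(0.4), 0.9, -2.0, -0.3)
# num * exp(mx) matches a, and den * exp(mx) matches b, up to rounding</pre></div>
</div>
</div>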
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Write the final running state back into this layer's slice of <code>state</code> (only when a state tensor was passed in; with <code>state=None</code> there is nothing to write to), and use the loop output as <em>wkv</em>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">207</span>        <span class="k">if</span> <span class="n">state</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">208</span>            <span class="n">state</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_id</span><span class="p">,</span> <span class="p">:,</span> <span class="n">NUM_STATE</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">num_state</span>
<span class="lineno">209</span>            <span class="n">state</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_id</span><span class="p">,</span> <span class="p">:,</span> <span class="n">DEN_STATE</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">den_state</span>
<span class="lineno">210</span>            <span class="n">state</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_id</span><span class="p">,</span> <span class="p">:,</span> <span class="n">MAX_STATE</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">max_state</span>
<span class="lineno">211</span>        <span class="n">wkv</span> <span class="o">=</span> <span class="n">output</span></pre></div>
</div>
</div>
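<div class='section'>
<div class='docs'>
<p>Writing the final <code>num_state</code>, <code>den_state</code> and <code>max_state</code> back into <code>state</code> is what lets a long sequence be processed in pieces. The single-channel sketch below uses a hypothetical <code>wkv_scan</code> that re-implements the loop with Python floats; it runs a sequence in one pass and in two chunks that carry the state across, and gets identical outputs.</p>
</div>
<div class='code'>
<div class="highlight"><pre>import math

# (num, den, max) for one channel, as in the `state is None` branch.
INIT_STATE = (0.0, 0.0, -1e38)


def wkv_scan(keys, values, time_decay, time_first, state=INIT_STATE):
    """Single-channel version of the WKV loop; returns (outputs, state)."""
    num_state, den_state, max_state = state
    outputs = []
    for k, v in zip(keys, values):
        # Output for this token: the current key gets the bonus time_first (u).
        m = max(max_state, k + time_first)
        e1 = math.exp(max_state - m)
        e2 = math.exp(k + time_first - m)
        outputs.append((e1 * num_state + e2 * v) / (e1 * den_state + e2))
        # Fold the token into the decayed history (no bonus).
        m = max(max_state + time_decay, k)
        e1 = math.exp(max_state + time_decay - m)
        e2 = math.exp(k - m)
        num_state = e1 * num_state + e2 * v
        den_state = e1 * den_state + e2
        max_state = m
    return outputs, (num_state, den_state, max_state)


keys = [0.5, -1.0, 2.0, 0.1, 3.0]
values = [1.0, 2.0, -1.0, 0.5, 4.0]
full, _ = wkv_scan(keys, values, time_decay=-0.4, time_first=0.2)
part1, carried = wkv_scan(keys[:2], values[:2], time_decay=-0.4, time_first=0.2)
part2, _ = wkv_scan(keys[2:], values[2:], time_decay=-0.4, time_first=0.2, state=carried)
# part1 + part2 == full: carrying the state reproduces the single pass exactly</pre></div>
</div>
</div>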
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">o</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">o</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span 
class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⊙</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.02691em;">w</span><span class="mord mathnormal" style="margin-right:0.03148em;">k</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">214</span>        <span class="n">rwkv</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">receptance</span><span class="p">)</span> <span class="o">*</span> <span class="n">wkv</span>
<span class="lineno">215</span>        <span class="n">rwkv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_proj</span><span class="p">(</span><span class="n">rwkv</span><span class="p">)</span>
<span class="lineno">216</span>
<span class="lineno">217</span>        <span class="k">return</span> <span class="n">rwkv</span><span class="p">,</span> <span class="n">state</span></pre></div>
</div>
</div>
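<div class='section'>
<div class='docs'>
<p>The receptance gate σ(<em>r<sub>t</sub></em>) lies in (0, 1) and scales each channel of <em>wkv<sub>t</sub></em> before the output projection. A pure-Python sketch, assuming <em>W<sub>o</sub></em> is given as a list of rows (in the model it is the linear layer <code>output_proj</code>; any bias it has is omitted here); the function names are hypothetical.</p>
<p>A strongly negative receptance closes the gate for that channel, so its mixed history contributes almost nothing to the output.</p>
</div>
<div class='code'>
<div class="highlight"><pre>import math


def sigmoid(x):
    """Logistic function written via tanh, which cannot overflow."""
    return 0.5 * (1.0 + math.tanh(0.5 * x))


def gated_output(receptance, wkv, w_o):
    """o = W_o (sigmoid(r) * wkv); w_o is a list of rows, no bias."""
    gated = [sigmoid(r) * a for r, a in zip(receptance, wkv)]
    return [sum(w * g for w, g in zip(row, gated)) for row in w_o]


print(gated_output([0.0, 0.0], [2.0, 4.0], [[1.0, 0.0], [0.5, 0.5]]))  # [1.0, 1.5]</pre></div>
</div>
</div>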
<div class='section' id='section-32'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<h2>RWKV block element</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">220</span><span class="k">class</span> <span class="nc">Block</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">225</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">,</span> <span class="n">layer_id</span><span class="p">):</span>
<span class="lineno">226</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">227</span>        <span class="bp">self</span><span class="o">.</span><span class="n">ln_1</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">config</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span>
<span class="lineno">228</span>        <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">TimeMixing</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">layer_id</span><span class="p">)</span>
<span class="lineno">229</span>        <span class="bp">self</span><span class="o">.</span><span class="n">ln_2</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">config</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span>
<span class="lineno">230</span>        <span class="bp">self</span><span class="o">.</span><span class="n">ffn</span> <span class="o">=</span> <span class="n">ChannelMixing</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">layer_id</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">232</span>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">state</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
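<p>state: <code>[number of layers, batch_size, 5, n_embd]</code>. The same tensor is passed through every block, and <code>TimeMixing</code> indexes it as <code>state[self.layer_id, :, slot, :]</code>, so each layer's slice has shape <code>[batch_size, 5, n_embd]</code>. </p>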
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
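<div class='section'>
<div class='docs'>
<p>Below is a minimal sketch of allocating this state for a whole stack of layers. It assumes one <code>[batch_size, 5, n_embd]</code> tensor per layer, stacked along a leading layer dimension. The slot names <code>PREV_X_TIME</code>, <code>NUM_STATE</code>, <code>DEN_STATE</code>, <code>MAX_STATE</code> and <code>PREV_X_CHANNEL</code>, their order, and the helper <code>init_state</code> are hypothetical labels chosen for illustration. Starting the running-maximum slot at a large negative value follows common RWKV-4 inference practice, and is also an assumption here.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch

# Hypothetical slot layout for the 5 state vectors kept per layer
# (an assumption for illustration, not taken from this file).
PREV_X_TIME, NUM_STATE, DEN_STATE, MAX_STATE, PREV_X_CHANNEL = range(5)


def init_state(n_layer, batch_size, n_embd):
    """Allocate one [batch_size, 5, n_embd] state per layer, stacked on dim 0."""
    state = torch.zeros(n_layer, batch_size, 5, n_embd)
    # The running maximum used for numerically stable exponentials starts very low
    state[:, :, MAX_STATE, :] = -1e30
    return state


state = init_state(n_layer=2, batch_size=3, n_embd=4)
print(state.shape)  # torch.Size([2, 3, 5, 4])
```
</pre></div>
</div>
</div>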
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Time mixing, with a pre-norm residual connection</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">236</span> <span class="n">residual</span> <span class="o">=</span> <span class="n">x</span>
<span class="lineno">237</span> <span class="n">x</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ln_1</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">state</span><span class="o">=</span><span class="n">state</span><span class="p">)</span>
<span class="lineno">238</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">residual</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Channel mixing, with a pre-norm residual connection</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">241</span> <span class="n">residual</span> <span class="o">=</span> <span class="n">x</span>
<span class="lineno">242</span> <span class="n">x</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ln_2</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">state</span><span class="o">=</span><span class="n">state</span><span class="p">)</span>
<span class="lineno">243</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">residual</span>
<span class="lineno">244</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">state</span></pre></div>
</div>
</div>
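<div class='section'>
<div class='docs'>
<p>The block above follows a pre-norm residual layout. Each sub-layer receives a normalized input, and its output is added back to the un-normalized stream. The sketch below shows that pattern. <code>ToyMixer</code> and <code>ToyBlock</code> are hypothetical stand-ins: each mixer is a single linear map that passes the state through unchanged. They are not the real <code>TimeMixing</code> and <code>ChannelMixing</code>, and <code>nn.LayerNorm</code> replaces this file's <code>LayerNorm</code>. Zeroing both mixers turns the block into an identity, which is a quick check that the residual path is wired correctly.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn as nn


class ToyMixer(nn.Module):
    """Hypothetical stand-in for TimeMixing / ChannelMixing:
    a linear map that returns the state unchanged."""

    def __init__(self, n_embd):
        super().__init__()
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x, state=None):
        return self.proj(x), state


class ToyBlock(nn.Module):
    """Pre-norm residual block mirroring the structure of Block.forward."""

    def __init__(self, n_embd):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = ToyMixer(n_embd)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.ffn = ToyMixer(n_embd)

    def forward(self, x, state=None):
        # Time mixing: normalize, mix, then add the residual
        h, state = self.attn(self.ln_1(x), state=state)
        x = x + h
        # Channel mixing: same pattern with the second norm
        h, state = self.ffn(self.ln_2(x), state=state)
        x = x + h
        return x, state


torch.manual_seed(0)
block = ToyBlock(n_embd=8)
x = torch.randn(2, 4, 8)  # [batch, time, n_embd]
y, state = block(x)
print(y.shape)  # torch.Size([2, 4, 8])
```
</pre></div>
</div>
</div>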
<div class='section' id='section-38'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<h2>RWKV</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">247</span><span class="k">class</span> <span class="nc">RWKV</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">251</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">,</span> <span class="n">lr_init</span><span class="o">=</span><span class="mf">0.0008</span><span class="p">):</span>
<span class="lineno">252</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">253</span> <span class="k">assert</span> <span class="n">config</span><span class="o">.</span><span class="n">vocab_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="lineno">254</span> <span class="k">assert</span> <span class="n">config</span><span class="o">.</span><span class="n">block_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="lineno">255</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span> <span class="o">=</span> <span class="n">config</span>
<span class="lineno">256</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr_init</span> <span class="o">=</span> <span class="n">lr_init</span> <span class="c1">## used to initialize embedding parameters</span>
<span class="lineno">257</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_layer</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">n_layer</span>
<span class="lineno">258</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_embd</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">n_embd</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Initialize model layers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">261</span> <span class="bp">self</span><span class="o">.</span><span class="n">rwkv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleDict</span><span class="p">(</span><span class="nb">dict</span><span class="p">(</span>
<span class="lineno">262</span> <span class="n">wte</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">config</span><span class="o">.</span><span class="n">n_embd</span><span class="p">),</span>
<span class="lineno">263</span> <span class="n">ln_p</span><span class="o">=</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">config</span><span class="o">.</span><span class="n">bias</span><span class="p">),</span>
<span class="lineno">264</span> <span class="n">h</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">Block</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">layer_id</span><span class="p">)</span> <span class="k">for</span> <span class="n">layer_id</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">n_layer</span><span class="p">)]),</span>
<span class="lineno">265</span> <span class="n">ln_f</span><span class="o">=</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">config</span><span class="o">.</span><span class="n">bias</span><span class="p">),</span>
<span class="lineno">266</span> <span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Output linear layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">269</span> <span class="bp">self</span><span class="o">.</span><span class="n">lm_head</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">config</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
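<div class='section'>
<div class='docs'>
<p>Below is a sketch of the resulting parameter layout. It assumes a hypothetical <code>ToyConfig</code> dataclass carrying the fields that <code>RWKV.__init__</code> reads. <code>nn.Identity</code> stands in for each <code>Block</code> and <code>nn.LayerNorm</code> for this file's <code>LayerNorm</code>, so only the embedding, the two norms and <code>lm_head</code> contribute parameters. Because <code>nn.ModuleDict</code> registers its entries as submodules, parameter names carry the <code>rwkv.</code> prefix (for example <code>rwkv.wte.weight</code>). This matters when loading checkpoints by name.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class ToyConfig:
    # Hypothetical config; field names match the attributes read by RWKV.__init__
    vocab_size: int = 50
    block_size: int = 16
    n_layer: int = 2
    n_embd: int = 8
    bias: bool = True  # unused here because nn.LayerNorm always has a bias


class ToySkeleton(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.rwkv = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            ln_p=nn.LayerNorm(config.n_embd),
            # nn.Identity stands in for Block(config, layer_id)
            h=nn.ModuleList([nn.Identity() for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)


model = ToySkeleton(ToyConfig())
names = [n for n, _ in model.named_parameters()]
n_params = sum(p.numel() for p in model.parameters())
print(names)
print(n_params)  # 832
```
</pre></div>
</div>
</div>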
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">271</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">,</span> <span class="n">targets</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">state</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">return_state</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="lineno">272</span> <span class="n">b</span><span class="p">,</span> <span class="n">t</span> <span class="o">=</span> <span class="n">idx</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
<span class="lineno">273</span> <span class="k">assert</span> <span class="n">t</span> <span class="o">&lt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">block_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;Cannot forward sequence of length </span><span class="si">{</span><span class="n">t</span><span class="si">}</span><span class="s2">, block size is only </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">block_size</span><span class="si">}</span><span class="s2">&quot;</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Embedding Layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">276</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rwkv</span><span class="o">.</span><span class="n">wte</span><span class="p">(</span><span class="n">idx</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Layer Norm </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">279</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rwkv</span><span class="o">.</span><span class="n">ln_p</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>RWKV blocks, followed by the final layer norm</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">282</span> <span class="k">for</span> <span class="n">block_idx</span><span class="p">,</span> <span class="n">block</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">rwkv</span><span class="o">.</span><span class="n">h</span><span class="p">):</span>
<span class="lineno">283</span> <span class="n">x</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">block</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">state</span><span class="p">)</span>
<span class="lineno">284</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rwkv</span><span class="o">.</span><span class="n">ln_f</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Logit layer and loss function (for training)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">287</span> <span class="k">if</span> <span class="n">targets</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>If targets are given, compute logits for every position and the cross-entropy loss; target positions set to <code>-1</code> are ignored</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">289</span> <span class="n">logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lm_head</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">290</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">cross_entropy</span><span class="p">(</span><span class="n">logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">logits</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)),</span> <span class="n">targets</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">ignore_index</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">291</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span><span class="p">:</span>
<span class="lineno">292</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">L2Wrap</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">logits</span><span class="p">)</span>
<span class="lineno">293</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
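<div class='section'>
<div class='docs'>
<p>The loss flattens the logits from <code>[b, t, vocab_size]</code> to <code>[b * t, vocab_size]</code> and the targets from <code>[b, t]</code> to <code>[b * t]</code>. The <code>ignore_index=-1</code> argument then drops padded positions. The small check below uses made-up shapes and random logits. It shows that this matches the mean cross-entropy over only the non-ignored positions.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, t, vocab = 2, 3, 5
logits = torch.randn(b, t, vocab)
targets = torch.tensor([[1, 2, -1], [0, -1, 4]])  # -1 marks positions to skip

# Flatten [b, t, vocab] -> [b*t, vocab] and [b, t] -> [b*t], as in RWKV.forward
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

# The same value computed over the non-ignored positions only
mask = targets.view(-1) != -1
manual = F.cross_entropy(logits.view(-1, vocab)[mask], targets.view(-1)[mask])
print(torch.allclose(loss, manual))  # True
```
</pre></div>
</div>
</div>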
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>Inference-time optimization: apply <code>lm_head</code> only to the last position</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">295</span> <span class="n">logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lm_head</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">:])</span> <span class="c1"># note: using list [-1] to preserve the time dim</span>
<span class="lineno">296</span> <span class="n">loss</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
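<div class='section'>
<div class='docs'>
<p>Indexing the time dimension with the list <code>[-1]</code>, rather than the integer <code>-1</code>, keeps it as a length-1 axis. The logits therefore stay <code>[b, 1, vocab_size]</code>. A quick illustration with an arbitrary tensor:</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch

x = torch.randn(2, 7, 4)   # [b, t, n_embd]
last_keep = x[:, [-1], :]  # list index keeps the time axis
last_drop = x[:, -1, :]    # integer index removes it
print(last_keep.shape)  # torch.Size([2, 1, 4])
print(last_drop.shape)  # torch.Size([2, 4])
```
</pre></div>
</div>
</div>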
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Return the logits and loss, plus the state if <code>return_state</code> is set</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">299</span> <span class="k">if</span> <span class="n">return_state</span><span class="p">:</span>
<span class="lineno">300</span> <span class="k">return</span> <span class="n">logits</span><span class="p">,</span> <span class="n">loss</span><span class="p">,</span> <span class="n">state</span>
<span class="lineno">301</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">302</span> <span class="k">return</span> <span class="n">logits</span><span class="p">,</span> <span class="n">loss</span></pre></div>
</div>
</div>
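<div class='section'>
<div class='docs'>
<p>With <code>return_state=True</code>, the returned state can be fed back in, so a sequence can be processed one token at a time. The sketch below only illustrates that calling convention. <code>ToyRecurrentLM</code> is a hypothetical model with the same <code>forward</code> signature, but its state is simply a running sum of embeddings rather than RWKV's time-mixing recurrence. Processing the tokens step by step, while passing the state along, gives the same last-position logits as a single call on the whole sequence.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn as nn


class ToyRecurrentLM(nn.Module):
    """Hypothetical stand-in with the same forward signature as RWKV.
    Its state is a running sum of token embeddings, so chunked calls that
    pass the state along reproduce a single full-sequence call."""

    def __init__(self, vocab_size=10, n_embd=4):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx, targets=None, state=None, return_state=False):
        x = self.wte(idx)  # [b, t, n_embd]
        if state is not None:
            x = x.clone()
            x[:, 0, :] = x[:, 0, :] + state  # carry over the previous sum
        x = torch.cumsum(x, dim=1)  # running sum over time
        new_state = x[:, -1, :]  # summary of everything seen so far
        logits = self.lm_head(x[:, [-1], :])  # last position only, as at inference
        loss = None
        if return_state:
            return logits, loss, new_state
        return logits, loss


torch.manual_seed(0)
model = ToyRecurrentLM().eval()
idx = torch.tensor([[1, 4, 2, 7]])
with torch.no_grad():
    full_logits, _ = model(idx)
    state = None
    for i in range(idx.size(1)):
        step_logits, _, state = model(idx[:, i:i + 1], state=state, return_state=True)
print(torch.allclose(full_logits, step_logits))  # True
```
</pre></div>
</div>
</div>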
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>