flash comments

This commit is contained in:
Varuna Jayasiri
2025-07-31 14:49:37 +05:30
parent 1bc2a69803
commit 0ae6e6ae2a
6 changed files with 1041 additions and 899 deletions

File diff suppressed because one or more lines are too long

View File

@ -103,7 +103,7 @@
<span class="lineno">19</span>
<span class="lineno">20</span>
<span class="lineno">21</span><span class="k">def</span> <span class="nf">_test_op</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">kv_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">,</span> <span class="n">causal</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
<span class="lineno">22</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Init&#39;</span><span class="p">):</span>
<span class="lineno">22</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Init </span><span class="si">{</span><span class="n">q_seq_len</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="n">kv_seq_len</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="n">d_head</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">):</span>
<span class="lineno">23</span> <span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">20</span><span class="p">)</span>
<span class="lineno">24</span> <span class="n">q</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">q_seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span>
<span class="lineno">25</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">())</span>
@ -200,13 +200,12 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span><span class="k">def</span> <span class="nf">_perf_triton_fn</span><span class="p">(</span><span class="o">*</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span>
<span class="lineno">92</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">,</span> <span class="n">causal</span><span class="p">,</span> <span class="p">):</span>
<span class="lineno">93</span> <span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span> <span class="o">*</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">94</span> <span class="n">k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">95</span> <span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">96</span> <span class="n">sm_scale</span> <span class="o">=</span> <span class="n">d_head</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span>
<span class="lineno">97</span> <span class="k">return</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">causal</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">91</span><span class="k">def</span> <span class="nf">_perf_triton_fn</span><span class="p">(</span><span class="o">*</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">,</span> <span class="n">causal</span><span class="p">):</span>
<span class="lineno">92</span> <span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span> <span class="o">*</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">93</span> <span class="n">k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">94</span> <span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">95</span> <span class="n">sm_scale</span> <span class="o">=</span> <span class="n">d_head</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span>
<span class="lineno">96</span> <span class="k">return</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">causal</span><span class="p">,</span> <span class="n">sm_scale</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -217,13 +216,12 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span><span class="k">def</span> <span class="nf">_perf_flash</span><span class="p">(</span><span class="o">*</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">,</span> <span class="n">causal</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span>
<span class="lineno">101</span> <span class="n">dtype</span><span class="p">):</span>
<span class="lineno">102</span> <span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">k_heads</span> <span class="o">*</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">103</span> <span class="n">k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">104</span> <span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">105</span> <span class="kn">from</span> <span class="nn">flash_attn</span> <span class="kn">import</span> <span class="n">flash_attn_func</span>
<span class="lineno">106</span> <span class="k">return</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">flash_attn_func</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">causal</span><span class="o">=</span><span class="n">causal</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">99</span><span class="k">def</span> <span class="nf">_perf_flash</span><span class="p">(</span><span class="o">*</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">,</span> <span class="n">causal</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
<span class="lineno">100</span> <span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">k_heads</span> <span class="o">*</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">101</span> <span class="n">k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">102</span> <span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">d_head</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">103</span> <span class="kn">from</span> <span class="nn">flash_attn</span> <span class="kn">import</span> <span class="n">flash_attn_func</span>
<span class="lineno">104</span> <span class="k">return</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">flash_attn_func</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">causal</span><span class="o">=</span><span class="n">causal</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -234,22 +232,22 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">109</span><span class="k">def</span> <span class="nf">_perf_fn</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">fn</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">,</span> <span class="n">causal</span><span class="p">,</span> <span class="n">is_bwd</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
<span class="lineno">110</span> <span class="k">if</span> <span class="n">is_bwd</span><span class="p">:</span>
<span class="lineno">111</span> <span class="n">o</span> <span class="o">=</span> <span class="n">fn</span><span class="p">()</span>
<span class="lineno">112</span> <span class="n">do</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>
<span class="lineno">113</span> <span class="n">fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">do</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">114</span> <span class="n">ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="n">fn</span><span class="p">)</span>
<span class="lineno">115</span>
<span class="lineno">116</span> <span class="n">flops_per_matmul</span> <span class="o">=</span> <span class="mf">2.0</span> <span class="o">*</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">k_heads</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">seq_len</span> <span class="o">*</span> <span class="n">seq_len</span> <span class="o">*</span> <span class="n">d_head</span>
<span class="lineno">117</span> <span class="n">total_flops</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">flops_per_matmul</span>
<span class="lineno">118</span> <span class="k">if</span> <span class="n">causal</span><span class="p">:</span>
<span class="lineno">119</span> <span class="n">total_flops</span> <span class="o">*=</span> <span class="mf">0.5</span>
<span class="lineno">120</span> <span class="k">if</span> <span class="n">is_bwd</span><span class="p">:</span>
<span class="lineno">121</span> <span class="n">total_flops</span> <span class="o">*=</span> <span class="mf">2.5</span> <span class="c1"># 2.0(bwd) + 0.5(recompute)</span>
<span class="lineno">122</span>
<span class="lineno">123</span> <span class="n">tf_ps</span> <span class="o">=</span> <span class="n">total_flops</span> <span class="o">*</span> <span class="mf">1e-12</span> <span class="o">/</span> <span class="p">(</span><span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-3</span><span class="p">)</span>
<span class="lineno">124</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">((</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="n">logger</span><span class="o">.</span><span class="n">Text</span><span class="o">.</span><span class="n">key</span><span class="p">),</span> <span class="s1">&#39;: &#39;</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">ms</span><span class="w"> </span><span class="si">:</span><span class="s1">,.1f</span><span class="si">}</span><span class="s1">ms&#39;</span><span class="p">,</span> <span class="s1">&#39; &#39;</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">tf_ps</span><span class="w"> </span><span class="si">:</span><span class="s1">,.2f</span><span class="si">}</span><span class="s1">TFps&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">107</span><span class="k">def</span> <span class="nf">_perf_fn</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">fn</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">k_heads</span><span class="p">,</span> <span class="n">n_groups</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">,</span> <span class="n">causal</span><span class="p">,</span> <span class="n">is_bwd</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
<span class="lineno">108</span> <span class="k">if</span> <span class="n">is_bwd</span><span class="p">:</span>
<span class="lineno">109</span> <span class="n">o</span> <span class="o">=</span> <span class="n">fn</span><span class="p">()</span>
<span class="lineno">110</span> <span class="n">do</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">o</span><span class="p">)</span>
<span class="lineno">111</span> <span class="n">fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">do</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">112</span> <span class="n">ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="n">fn</span><span class="p">)</span>
<span class="lineno">113</span>
<span class="lineno">114</span> <span class="n">flops_per_matmul</span> <span class="o">=</span> <span class="mf">2.0</span> <span class="o">*</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">k_heads</span> <span class="o">*</span> <span class="n">n_groups</span> <span class="o">*</span> <span class="n">seq_len</span> <span class="o">*</span> <span class="n">seq_len</span> <span class="o">*</span> <span class="n">d_head</span>
<span class="lineno">115</span> <span class="n">total_flops</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">flops_per_matmul</span>
<span class="lineno">116</span> <span class="k">if</span> <span class="n">causal</span><span class="p">:</span>
<span class="lineno">117</span> <span class="n">total_flops</span> <span class="o">*=</span> <span class="mf">0.5</span>
<span class="lineno">118</span> <span class="k">if</span> <span class="n">is_bwd</span><span class="p">:</span>
<span class="lineno">119</span> <span class="n">total_flops</span> <span class="o">*=</span> <span class="mf">2.5</span> <span class="c1"># 2.0(bwd) + 0.5(recompute)</span>
<span class="lineno">120</span>
<span class="lineno">121</span> <span class="n">tf_ps</span> <span class="o">=</span> <span class="n">total_flops</span> <span class="o">*</span> <span class="mf">1e-12</span> <span class="o">/</span> <span class="p">(</span><span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-3</span><span class="p">)</span>
<span class="lineno">122</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">((</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="n">logger</span><span class="o">.</span><span class="n">Text</span><span class="o">.</span><span class="n">key</span><span class="p">),</span> <span class="s1">&#39;: &#39;</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">ms</span><span class="w"> </span><span class="si">:</span><span class="s1">,.1f</span><span class="si">}</span><span class="s1">ms&#39;</span><span class="p">,</span> <span class="s1">&#39; &#39;</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">tf_ps</span><span class="w"> </span><span class="si">:</span><span class="s1">,.2f</span><span class="si">}</span><span class="s1">TFps&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -260,11 +258,11 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span>
<span class="lineno">128</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">)</span>
<span class="lineno">129</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">130</span>
<span class="lineno">131</span> <span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">bfloat16</span></pre></div>
<div class="highlight"><pre><span class="lineno">125</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span>
<span class="lineno">126</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">)</span>
<span class="lineno">127</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">128</span>
<span class="lineno">129</span> <span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -276,32 +274,32 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">_test_op</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2048</span><span class="p">,</span> <span class="mi">2048</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="kc">True</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">135</span> <span class="n">_test_op</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">2048</span><span class="p">,</span> <span class="mi">4096</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">136</span> <span class="n">_test_op</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">2048</span><span class="p">,</span> <span class="mi">1024</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">137</span> <span class="n">_test_op</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">2048</span><span class="p">,</span> <span class="mi">2048</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="kc">True</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">138</span>
<span class="lineno">139</span> <span class="n">_conf</span> <span class="o">=</span> <span class="p">{</span>
<span class="lineno">140</span> <span class="s1">&#39;batch_size&#39;</span><span class="p">:</span> <span class="mi">16</span><span class="p">,</span>
<span class="lineno">141</span> <span class="s1">&#39;k_heads&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span>
<span class="lineno">142</span> <span class="s1">&#39;n_groups&#39;</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span>
<span class="lineno">143</span> <span class="s1">&#39;seq_len&#39;</span><span class="p">:</span> <span class="mi">2048</span><span class="p">,</span>
<span class="lineno">144</span> <span class="s1">&#39;d_head&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
<span class="lineno">145</span> <span class="p">}</span>
<span class="lineno">146</span>
<span class="lineno">147</span> <span class="k">for</span> <span class="n">_causal</span> <span class="ow">in</span> <span class="p">[</span><span class="kc">False</span><span class="p">,</span> <span class="kc">True</span><span class="p">]:</span>
<span class="lineno">148</span> <span class="k">for</span> <span class="n">is_bwd</span> <span class="ow">in</span> <span class="p">[</span><span class="kc">False</span><span class="p">,</span> <span class="kc">True</span><span class="p">]:</span>
<span class="lineno">149</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="s2">&quot;Causal&quot;</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="n">_causal</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="s2">&quot;Non-causal&quot;</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="s2">&quot; Backward&quot;</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="n">is_bwd</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="s2">&quot;&quot;</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="n">logger</span><span class="o">.</span><span class="n">Text</span><span class="o">.</span><span class="n">title</span><span class="p">)</span>
<span class="lineno">150</span> <span class="n">_perf_fn</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;flash&#39;</span><span class="p">,</span> <span class="n">_perf_flash</span><span class="p">(</span><span class="n">causal</span><span class="o">=</span><span class="n">_causal</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="o">**</span><span class="n">_conf</span><span class="p">),</span>
<span class="lineno">151</span> <span class="n">is_bwd</span><span class="o">=</span><span class="n">is_bwd</span><span class="p">,</span>
<span class="lineno">152</span> <span class="n">causal</span><span class="o">=</span><span class="n">_causal</span><span class="p">,</span> <span class="o">**</span><span class="n">_conf</span><span class="p">)</span>
<span class="lineno">153</span> <span class="n">_perf_fn</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;triton&#39;</span><span class="p">,</span> <span class="n">_perf_triton_fn</span><span class="p">(</span><span class="n">causal</span><span class="o">=</span><span class="n">_causal</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="o">**</span><span class="n">_conf</span><span class="p">),</span>
<span class="lineno">154</span> <span class="n">is_bwd</span><span class="o">=</span><span class="n">is_bwd</span><span class="p">,</span>
<span class="lineno">155</span> <span class="n">causal</span><span class="o">=</span><span class="n">_causal</span><span class="p">,</span> <span class="o">**</span><span class="n">_conf</span><span class="p">)</span>
<span class="lineno">156</span>
<span class="lineno">157</span>
<span class="lineno">158</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
<span class="lineno">159</span> <span class="n">_test</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">132</span> <span class="n">_test_op</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2048</span><span class="p">,</span> <span class="mi">2048</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="kc">True</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">133</span> <span class="n">_test_op</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">2001</span><span class="p">,</span> <span class="mi">4001</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">134</span> <span class="n">_test_op</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">2048</span><span class="p">,</span> <span class="mi">1024</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">135</span> <span class="n">_test_op</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">2001</span><span class="p">,</span> <span class="mi">4001</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="kc">True</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">136</span>
<span class="lineno">137</span> <span class="n">_conf</span> <span class="o">=</span> <span class="p">{</span>
<span class="lineno">138</span> <span class="s1">&#39;batch_size&#39;</span><span class="p">:</span> <span class="mi">16</span><span class="p">,</span>
<span class="lineno">139</span> <span class="s1">&#39;k_heads&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span>
<span class="lineno">140</span> <span class="s1">&#39;n_groups&#39;</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span>
<span class="lineno">141</span> <span class="s1">&#39;seq_len&#39;</span><span class="p">:</span> <span class="mi">2048</span><span class="p">,</span>
<span class="lineno">142</span> <span class="s1">&#39;d_head&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
<span class="lineno">143</span> <span class="p">}</span>
<span class="lineno">144</span>
<span class="lineno">145</span> <span class="k">for</span> <span class="n">_causal</span> <span class="ow">in</span> <span class="p">[</span><span class="kc">False</span><span class="p">,</span> <span class="kc">True</span><span class="p">]:</span>
<span class="lineno">146</span> <span class="k">for</span> <span class="n">is_bwd</span> <span class="ow">in</span> <span class="p">[</span><span class="kc">False</span><span class="p">,</span> <span class="kc">True</span><span class="p">]:</span>
<span class="lineno">147</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="s2">&quot;Causal&quot;</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="n">_causal</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="s2">&quot;Non-causal&quot;</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="s2">&quot; Backward&quot;</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="n">is_bwd</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="s2">&quot;&quot;</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="n">logger</span><span class="o">.</span><span class="n">Text</span><span class="o">.</span><span class="n">title</span><span class="p">)</span>
<span class="lineno">148</span> <span class="n">_perf_fn</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;flash&#39;</span><span class="p">,</span> <span class="n">_perf_flash</span><span class="p">(</span><span class="n">causal</span><span class="o">=</span><span class="n">_causal</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="o">**</span><span class="n">_conf</span><span class="p">),</span>
<span class="lineno">149</span> <span class="n">is_bwd</span><span class="o">=</span><span class="n">is_bwd</span><span class="p">,</span>
<span class="lineno">150</span> <span class="n">causal</span><span class="o">=</span><span class="n">_causal</span><span class="p">,</span> <span class="o">**</span><span class="n">_conf</span><span class="p">)</span>
<span class="lineno">151</span> <span class="n">_perf_fn</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;triton&#39;</span><span class="p">,</span> <span class="n">_perf_triton_fn</span><span class="p">(</span><span class="n">causal</span><span class="o">=</span><span class="n">_causal</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="o">**</span><span class="n">_conf</span><span class="p">),</span>
<span class="lineno">152</span> <span class="n">is_bwd</span><span class="o">=</span><span class="n">is_bwd</span><span class="p">,</span>
<span class="lineno">153</span> <span class="n">causal</span><span class="o">=</span><span class="n">_causal</span><span class="p">,</span> <span class="o">**</span><span class="n">_conf</span><span class="p">)</span>
<span class="lineno">154</span>
<span class="lineno">155</span>
<span class="lineno">156</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
<span class="lineno">157</span> <span class="n">_test</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>