fix ddpm attn

Author: Varuna Jayasiri
Date: 2022-09-12 08:30:27 +05:30
parent 4eec7a4e5e
commit 7d1550dd67
4 changed files with 4 additions and 4 deletions

@@ -651,7 +651,7 @@
-attn = attn.softmax(dim=1)
+attn = attn.softmax(dim=2)

@@ -547,7 +547,7 @@
 <url>
   <loc>https://nn.labml.ai/diffusion/ddpm/index.html</loc>
-  <lastmod>2022-09-07T16:30:00+00:00</lastmod>
+  <lastmod>2022-09-11T16:30:00+00:00</lastmod>
   <priority>1.00</priority>
 </url>

@@ -185,7 +185,7 @@ class AttentionBlock(Module):
         # Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$
         attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
         # Softmax along the sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
-        attn = attn.softmax(dim=1)
+        attn = attn.softmax(dim=2)
         # Multiply by values
         res = torch.einsum('bijh,bjhd->bihd', attn, v)
         # Reshape to `[batch_size, seq, n_heads * d_k]`
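
Why dim=2: the einsum 'bihd,bjhd->bijh' produces attn with shape [batch_size, i, j, n_heads], where i indexes queries and j indexes keys. The softmax must normalize over the key axis j (dim=2) so that each query's attention weights sum to one; dim=1 normalized over the query axis instead. A minimal sketch of the corrected computation (tensor sizes are illustrative, not taken from the commit):

import torch

# Illustrative sizes; hypothetical, not from the commit
batch_size, seq, n_heads, d_k = 2, 16, 4, 32
q = torch.randn(batch_size, seq, n_heads, d_k)
k = torch.randn(batch_size, seq, n_heads, d_k)
v = torch.randn(batch_size, seq, n_heads, d_k)
scale = d_k ** -0.5

# attn[b, i, j, h] is the scaled dot-product of query i and key j for head h
attn = torch.einsum('bihd,bjhd->bijh', q, k) * scale

# Normalize over the key axis j (dim=2) so each query's weights sum to 1;
# dim=1 would normalize over the query axis i, which was the bug.
attn = attn.softmax(dim=2)
assert torch.allclose(attn.sum(dim=2), torch.ones(batch_size, seq, n_heads))

# Weighted sum of values, matching the surrounding code
res = torch.einsum('bijh,bjhd->bihd', attn, v)  # [batch_size, seq, n_heads, d_k]

The same one-character fix appears in the generated documentation hunk above.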

@@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
 setuptools.setup(
     name='labml-nn',
-    version='0.4.130',
+    version='0.4.131',
     author="Varuna Jayasiri, Nipun Wijerathne",
     author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
     description="🧑‍🏫 Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit), optimizers (adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, diffusion, etc. 🧠",