mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 01:13:00 +08:00
fix ddpm attn
@@ -651,7 +651,7 @@
 </div>
 <div class='code'>
-<div class="highlight"><pre><span class="lineno">188</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
+<div class="highlight"><pre><span class="lineno">188</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div>
 </div>
 </div>
 <div class='section' id='section-42'>
@@ -547,7 +547,7 @@
 <url>
     <loc>https://nn.labml.ai/diffusion/ddpm/index.html</loc>
-    <lastmod>2022-09-07T16:30:00+00:00</lastmod>
+    <lastmod>2022-09-11T16:30:00+00:00</lastmod>
     <priority>1.00</priority>
 </url>
@@ -185,7 +185,7 @@ class AttentionBlock(Module):
         # Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$
         attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
         # Softmax along the sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
-        attn = attn.softmax(dim=1)
+        attn = attn.softmax(dim=2)
         # Multiply by values
         res = torch.einsum('bijh,bjhd->bihd', attn, v)
         # Reshape to `[batch_size, seq, n_heads * d_k]`
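The change above can be sanity-checked with a small standalone sketch (the tensor shapes and variable names below are illustrative assumptions, not taken from the commit). With the einsum pattern 'bihd,bjhd->bijh', attn has shape [batch, query position i, key position j, heads], so a softmax over the sequence of keys has to run along dim=2; dim=1 would normalise over query positions instead.

# Minimal sketch, assuming toy tensor shapes, to confirm that softmax(dim=2)
# normalises the attention weights over the key positions as the comment intends.
import torch

batch, seq, n_heads, d_k = 2, 5, 4, 8                    # illustrative sizes
q = torch.randn(batch, seq, n_heads, d_k)
k = torch.randn(batch, seq, n_heads, d_k)
v = torch.randn(batch, seq, n_heads, d_k)
scale = d_k ** -0.5

attn = torch.einsum('bihd,bjhd->bijh', q, k) * scale     # [batch, i, j, heads]
attn = attn.softmax(dim=2)                               # normalise over keys j (the fix)

# each query's weights over all keys now sum to one
assert torch.allclose(attn.sum(dim=2), torch.ones(batch, seq, n_heads), atol=1e-6)

res = torch.einsum('bijh,bjhd->bihd', attn, v)           # [batch, i, heads, d_k]
print(res.shape)                                         # torch.Size([2, 5, 4, 8])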
setup.py (2 changed lines)

@@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
 
 setuptools.setup(
     name='labml-nn',
-    version='0.4.130',
+    version='0.4.131',
     author="Varuna Jayasiri, Nipun Wijerathne",
     author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
     description="🧑🏫 Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit), optimizers (adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, diffusion, etc. 🧠",
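The accompanying version bump suggests the fix is intended to ship in 0.4.131 of the PyPI package. A quick way to confirm which version is installed locally (hypothetical usage, not part of the commit):

# Hypothetical check of the installed labml-nn version (requires Python 3.8+).
import importlib.metadata
print(importlib.metadata.version('labml-nn'))   # prints 0.4.131 once this release is installed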