From 1a49b753e40e7efe93c97e1da19147f04f634a05 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Sat, 24 Sep 2022 14:41:55 +0530
Subject: [PATCH] fix

---
 .../model/unet_attention.html                 | 145 ++++++++++--------
 .../stable_diffusion/model/unet_attention.py  |   8 +-
 2 files changed, 81 insertions(+), 72 deletions(-)

diff --git a/docs/diffusion/stable_diffusion/model/unet_attention.html b/docs/diffusion/stable_diffusion/model/unet_attention.html
index 4399937e..3bb8e819 100644
--- a/docs/diffusion/stable_diffusion/model/unet_attention.html
+++ b/docs/diffusion/stable_diffusion/model/unet_attention.html
[145 changed lines of regenerated documentation markup: the rendered code listing and line-number anchors are updated to match the source changes in unet_attention.py below]
diff --git a/labml_nn/diffusion/stable_diffusion/model/unet_attention.py b/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
index 79baa603..15f18f5c 100644
--- a/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
+++ b/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
@@ -164,8 +164,10 @@ class CrossAttention(nn.Module):
         :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
         """
 
+        has_cond = cond is not None
+
         # If `cond` is `None` we perform self attention
-        if cond is None:
+        if not has_cond:
             cond = x
 
         # Get query, key and value vectors
@@ -173,9 +175,7 @@ class CrossAttention(nn.Module):
         k = self.to_k(cond)
         v = self.to_v(cond)
 
-        print('use flash', CrossAttention.use_flash_attention, self.flash)
-
-        if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
+        if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
             return self.flash_attention(q, k, v)
         else:
             return self.normal_attention(q, k, v)
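
Note: besides dropping the debug print, the second hunk fixes the flash-attention dispatch. In the old code `cond` was overwritten with `x` for self attention before the `cond is None` check ran, so that check was always False and the flash path was never taken. The minimal sketch below (hypothetical helper names, not the library code) illustrates the behaviour before and after capturing `has_cond` up front:

# Minimal sketch of the dispatch bug, using made-up helpers
# (choose_path_buggy / choose_path_fixed are illustrative, not part of labml_nn).

def choose_path_buggy(x, cond, flash_available=True, d_head=64):
    if cond is None:
        cond = x  # cond is no longer None after this line
    # this check can never be True, so the flash path is never taken
    if flash_available and cond is None and d_head <= 128:
        return 'flash'
    return 'normal'

def choose_path_fixed(x, cond, flash_available=True, d_head=64):
    has_cond = cond is not None  # remember the original condition first
    if not has_cond:
        cond = x
    if flash_available and not has_cond and d_head <= 128:
        return 'flash'
    return 'normal'

assert choose_path_buggy('x', None) == 'normal'    # bug: self attention falls back to the normal path
assert choose_path_fixed('x', None) == 'flash'     # fixed: self attention can use flash attention
assert choose_path_fixed('x', 'cond') == 'normal'  # cross attention still uses the normal path

With the fix, flash attention remains gated on `CrossAttention.use_flash_attention`, an available `self.flash` module, and `d_head <= 128`, exactly as in the patched condition.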