diff --git a/docs/diffusion/stable_diffusion/model/unet_attention.html b/docs/diffusion/stable_diffusion/model/unet_attention.html
index f5e5331a..b7118b0a 100644
--- a/docs/diffusion/stable_diffusion/model/unet_attention.html
+++ b/docs/diffusion/stable_diffusion/model/unet_attention.html
@@ -602,10 +602,12 @@
 173        k = self.to_k(cond)
 174        v = self.to_v(cond)
 175
-176        if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
-177            return self.flash_attention(q, k, v)
-178        else:
-179            return self.normal_attention(q, k, v)
+176        print('use flash', CrossAttention.use_flash_attention)
+177
+178        if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
+179            return self.flash_attention(q, k, v)
+180        else:
+181            return self.normal_attention(q, k, v)
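
The only source change behind this diff is the debug print added at the top of CrossAttention's forward dispatch; every hunk that follows is just the resulting two-line shift of the rendered line numbers. A quick usage sketch of the class-level switch the new print reports, assuming the import path mirrors the docs path (it may differ in your checkout):

    # Hypothetical usage of the flag the new print reports.
    from labml_nn.diffusion.stable_diffusion.model.unet_attention import CrossAttention

    CrossAttention.use_flash_attention = True   # class-level flag checked on every forward()
    # Each CrossAttention.forward() call now logs, e.g.:
    #   use flash True
    # and only takes the flash_attention() path when self.flash is available,
    # there is no cross-attention conditioning (cond is None) and d_head <= 128.
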
@@ -625,7 +627,7 @@
-181    def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+183    def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
@@ -636,7 +638,7 @@
-188        print('flash')
+190        print('flash')
@@ -647,7 +649,7 @@ MarkdownException + Italic: not ending with *
-191        batch_size, seq_len, _ = q.shape
+193        batch_size, seq_len, _ = q.shape
@@ -663,7 +665,7 @@
-195        qkv = torch.stack((q, k, v), dim=2)
+197        qkv = torch.stack((q, k, v), dim=2)
@@ -675,7 +677,7 @@
-197        qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
+199        qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
@@ -690,14 +692,14 @@
-201        if self.d_head <= 32:
-202            pad = 32 - self.d_head
-203        elif self.d_head <= 64:
-204            pad = 64 - self.d_head
-205        elif self.d_head <= 128:
-206            pad = 128 - self.d_head
-207        else:
-208            raise ValueError(f'Head size ${self.d_head} too large for Flash Attention')
+203        if self.d_head <= 32:
+204            pad = 32 - self.d_head
+205        elif self.d_head <= 64:
+206            pad = 64 - self.d_head
+207        elif self.d_head <= 128:
+208            pad = 128 - self.d_head
+209        else:
+210            raise ValueError(f'Head size ${self.d_head} too large for Flash Attention')
@@ -709,8 +711,8 @@
-211        if pad:
-212            qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1)
+213        if pad:
+214            qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1)
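
The two hunks above pad the head dimension up to the nearest size the flash kernel supports (32, 64 or 128) by concatenating zeros onto the packed qkv tensor; the zeros add nothing to the dot products and the padded output columns are sliced off later. A standalone sketch of that padding step with made-up shapes:

    import torch

    # Made-up shapes: e.g. 320 channels split over 8 heads gives d_head = 40.
    batch_size, seq_len, n_heads, d_head = 2, 4096, 8, 40
    qkv = torch.randn(batch_size, seq_len, 3, n_heads, d_head)

    # Round d_head up to 32, 64 or 128, the head sizes supported here.
    if d_head <= 32:
        pad = 32 - d_head
    elif d_head <= 64:
        pad = 64 - d_head
    elif d_head <= 128:
        pad = 128 - d_head
    else:
        raise ValueError(f'Head size {d_head} too large for Flash Attention')

    # Zero-pad the last (feature) dimension of the packed q, k, v.
    if pad:
        qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, n_heads, pad)), dim=-1)

    assert qkv.shape[-1] == 64   # 40 is rounded up to 64 in this example
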
@@ -721,7 +723,7 @@ KeyError + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)V'
-217        out, _ = self.flash(qkv)
+219        out, _ = self.flash(qkv)
@@ -733,7 +735,7 @@
-219        out = out[:, :, :, :self.d_head]
+221        out = out[:, :, :, :self.d_head]
@@ -746,7 +748,7 @@
-221        out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
+223        out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
@@ -759,7 +761,7 @@
-224        return self.to_out(out)
+226        return self.to_out(out)
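
The flash path packs q, k and v into a single [batch, seq, 3, heads, d_head] tensor, pads the head size, runs the fused kernel, then strips the padding, merges the heads and applies to_out. A rough sketch of the same shape bookkeeping, with PyTorch's built-in scaled_dot_product_attention standing in for the flash_attn kernel (an assumption, not the labml code; the stand-in needs no head-size padding):

    import torch
    import torch.nn.functional as F

    def flash_like_attention(q, k, v, n_heads, d_head):
        batch_size, seq_len, _ = q.shape

        # Pack and split into heads: [batch, seq, 3, heads, d_head]
        qkv = torch.stack((q, k, v), dim=2)
        qkv = qkv.view(batch_size, seq_len, 3, n_heads, d_head)

        # [batch, seq, 3, heads, d] -> three [batch, heads, seq, d] tensors
        qh, kh, vh = qkv.permute(2, 0, 3, 1, 4)
        out = F.scaled_dot_product_attention(qh, kh, vh)

        # Merge heads back to [batch, seq, heads * d_head];
        # the real method then applies self.to_out on top of this.
        return out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, n_heads * d_head)

    q = k = v = torch.randn(1, 4096, 320)
    assert flash_like_attention(q, k, v, n_heads=8, d_head=40).shape == (1, 4096, 320)
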
@@ -779,7 +781,7 @@
-226    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+228    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
@@ -792,9 +794,9 @@
-234        q = q.view(*q.shape[:2], self.n_heads, -1)
-235        k = k.view(*k.shape[:2], self.n_heads, -1)
-236        v = v.view(*v.shape[:2], self.n_heads, -1)
+236        q = q.view(*q.shape[:2], self.n_heads, -1)
+237        k = k.view(*k.shape[:2], self.n_heads, -1)
+238        v = v.view(*v.shape[:2], self.n_heads, -1)
@@ -805,7 +807,7 @@ KeyError + '\\frac{Q K^\\top}{\\sqrt{d_{key}}}'
-239        attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
+241        attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
@@ -816,12 +818,12 @@ KeyError + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)'
-243        if self.is_inplace:
-244            half = attn.shape[0] // 2
-245            attn[half:] = attn[half:].softmax(dim=-1)
-246            attn[:half] = attn[:half].softmax(dim=-1)
-247        else:
-248            attn = attn.softmax(dim=-1)
+245        if self.is_inplace:
+246            half = attn.shape[0] // 2
+247            attn[half:] = attn[half:].softmax(dim=-1)
+248            attn[:half] = attn[:half].softmax(dim=-1)
+249        else:
+250            attn = attn.softmax(dim=-1)
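
When is_inplace is set, the softmax is computed half a batch at a time and written back into attn, which roughly halves the peak memory held by the attention map (useful when the batch carries, for example, the conditional and unconditional halves of classifier-free guidance). The trick in isolation, with toy shapes:

    import torch

    attn = torch.randn(4, 8, 1024, 1024)   # toy [batch, heads, seq, seq] scores

    # Softmax one half of the batch at a time, overwriting attn in place,
    # so only half of the map needs scratch memory at once.
    half = attn.shape[0] // 2
    attn[half:] = attn[half:].softmax(dim=-1)
    attn[:half] = attn[:half].softmax(dim=-1)

    assert torch.allclose(attn.sum(dim=-1), torch.ones(4, 8, 1024))
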
@@ -832,7 +834,7 @@ KeyError + '\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)V'
-252        out = torch.einsum('bhij,bjhd->bihd', attn, v)
+254        out = torch.einsum('bhij,bjhd->bihd', attn, v)
@@ -845,7 +847,7 @@
-254        out = out.reshape(*out.shape[:2], -1)
+256        out = out.reshape(*out.shape[:2], -1)
@@ -858,7 +860,7 @@
-256        return self.to_out(out)
+258        return self.to_out(out)
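
normal_attention splits q, k and v into heads, forms the scaled scores with an einsum, softmaxes over the key positions and contracts with v before merging the heads again. A self-contained sketch of the same computation (without the to_out projection), assuming scale = d_head ** -0.5 as in the constructor:

    import torch

    def normal_attention_sketch(q, k, v, n_heads):
        d_head = q.shape[-1] // n_heads
        scale = d_head ** -0.5                      # assumed constructor value

        # Split heads: [batch, seq, heads, d_head]
        q = q.view(*q.shape[:2], n_heads, -1)
        k = k.view(*k.shape[:2], n_heads, -1)
        v = v.view(*v.shape[:2], n_heads, -1)

        # Scaled scores [batch, heads, seq_q, seq_k], softmax over key positions j
        attn = torch.einsum('bihd,bjhd->bhij', q, k) * scale
        attn = attn.softmax(dim=-1)

        # Weighted sum of values, then merge heads
        out = torch.einsum('bhij,bjhd->bihd', attn, v)
        return out.reshape(*out.shape[:2], -1)

    q = torch.randn(2, 4096, 320)
    kv = torch.randn(2, 77, 320)                    # cross-attention shapes only
    assert normal_attention_sketch(q, kv, kv, n_heads=8).shape == (2, 4096, 320)
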
@@ -870,7 +872,7 @@
-259class FeedForward(nn.Module):
+261class FeedForward(nn.Module):
@@ -885,7 +887,7 @@
-264    def __init__(self, d_model: int, d_mult: int = 4):
+266    def __init__(self, d_model: int, d_mult: int = 4):
@@ -896,12 +898,12 @@
-269        super().__init__()
-270        self.net = nn.Sequential(
-271            GeGLU(d_model, d_model * d_mult),
-272            nn.Dropout(0.),
-273            nn.Linear(d_model * d_mult, d_model)
-274        )
+271        super().__init__()
+272        self.net = nn.Sequential(
+273            GeGLU(d_model, d_model * d_mult),
+274            nn.Dropout(0.),
+275            nn.Linear(d_model * d_mult, d_model)
+276        )
@@ -912,8 +914,8 @@
-276    def forward(self, x: torch.Tensor):
-277        return self.net(x)
+278    def forward(self, x: torch.Tensor):
+279        return self.net(x)
@@ -924,7 +926,7 @@ KeyError + '\\text{GeGLU}(x) = (xW + b) * \\text{GELU}(xV + c)'
-280class GeGLU(nn.Module):
+282class GeGLU(nn.Module):
@@ -935,8 +937,8 @@
-287    def __init__(self, d_in: int, d_out: int):
-288        super().__init__()
+289    def __init__(self, d_in: int, d_out: int):
+290        super().__init__()
@@ -947,7 +949,7 @@ KeyError + 'xW + b'
-290        self.proj = nn.Linear(d_in, d_out * 2)
+292        self.proj = nn.Linear(d_in, d_out * 2)
@@ -958,7 +960,7 @@
-292    def forward(self, x: torch.Tensor):
+294    def forward(self, x: torch.Tensor):
@@ -969,7 +971,7 @@ KeyError + 'xW + b'
-294        x, gate = self.proj(x).chunk(2, dim=-1)
+296        x, gate = self.proj(x).chunk(2, dim=-1)
@@ -980,7 +982,7 @@ KeyError + '\\text{GeGLU}(x) = (xW + b) * \\text{GELU}(xV + c)'
-296        return x * F.gelu(gate)
+298        return x * F.gelu(gate)
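
GeGLU projects to twice the output width, splits the result into a value and a gate, and returns value * GELU(gate), i.e. GeGLU(x) = (xW + b) * GELU(xV + c); FeedForward wraps it between a dropout and a linear projection back to d_model. A self-contained sketch of both modules as they appear in these hunks:

    import torch
    from torch import nn
    import torch.nn.functional as F

    class GeGLU(nn.Module):
        # GeGLU(x) = (xW + b) * GELU(xV + c), with W and V fused into one projection.
        def __init__(self, d_in: int, d_out: int):
            super().__init__()
            self.proj = nn.Linear(d_in, d_out * 2)

        def forward(self, x: torch.Tensor):
            x, gate = self.proj(x).chunk(2, dim=-1)   # value half and gate half
            return x * F.gelu(gate)

    class FeedForward(nn.Module):
        # GeGLU expansion by d_mult, dropout, then projection back to d_model.
        def __init__(self, d_model: int, d_mult: int = 4):
            super().__init__()
            self.net = nn.Sequential(
                GeGLU(d_model, d_model * d_mult),
                nn.Dropout(0.),
                nn.Linear(d_model * d_mult, d_model),
            )

        def forward(self, x: torch.Tensor):
            return self.net(x)

    x = torch.randn(2, 4096, 320)
    assert FeedForward(320)(x).shape == (2, 4096, 320)
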