diff --git a/docs/diffusion/stable_diffusion/model/unet_attention.html b/docs/diffusion/stable_diffusion/model/unet_attention.html
index f5e5331a..b7118b0a 100644
--- a/docs/diffusion/stable_diffusion/model/unet_attention.html
+++ b/docs/diffusion/stable_diffusion/model/unet_attention.html
@@ -602,10 +602,12 @@
 173 k = self.to_k(cond)
 174 v = self.to_v(cond)
 175
-176 if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
-177 return self.flash_attention(q, k, v)
-178 else:
-179 return self.normal_attention(q, k, v)
+176 print('use flash', CrossAttention.use_flash_attention)
+177
+178 if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
+179 return self.flash_attention(q, k, v)
+180 else:
+181 return self.normal_attention(q, k, v)
-181 def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+183 def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
-188 print('flash')
+190 print('flash')
-191 batch_size, seq_len, _ = q.shape
+193 batch_size, seq_len, _ = q.shape
-195 qkv = torch.stack((q, k, v), dim=2)
+197 qkv = torch.stack((q, k, v), dim=2)
-197 qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
+199 qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
-201 if self.d_head <= 32:
-202 pad = 32 - self.d_head
-203 elif self.d_head <= 64:
-204 pad = 64 - self.d_head
-205 elif self.d_head <= 128:
-206 pad = 128 - self.d_head
-207 else:
-208 raise ValueError(f'Head size ${self.d_head} too large for Flash Attention')
+203 if self.d_head <= 32:
+204 pad = 32 - self.d_head
+205 elif self.d_head <= 64:
+206 pad = 64 - self.d_head
+207 elif self.d_head <= 128:
+208 pad = 128 - self.d_head
+209 else:
+210 raise ValueError(f'Head size ${self.d_head} too large for Flash Attention')
-211 if pad:
-212 qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1)
+213 if pad:
+214 qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1)
-217 out, _ = self.flash(qkv)
+219 out, _ = self.flash(qkv)
-219 out = out[:, :, :, :self.d_head]
+221 out = out[:, :, :, :self.d_head]
-221 out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
+223 out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
-224 return self.to_out(out)
+226 return self.to_out(out)
-226 def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+228 def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
-234 q = q.view(*q.shape[:2], self.n_heads, -1)
-235 k = k.view(*k.shape[:2], self.n_heads, -1)
-236 v = v.view(*v.shape[:2], self.n_heads, -1)
+236 q = q.view(*q.shape[:2], self.n_heads, -1)
+237 k = k.view(*k.shape[:2], self.n_heads, -1)
+238 v = v.view(*v.shape[:2], self.n_heads, -1)
-239 attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
+241 attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
-243 if self.is_inplace:
-244 half = attn.shape[0] // 2
-245 attn[half:] = attn[half:].softmax(dim=-1)
-246 attn[:half] = attn[:half].softmax(dim=-1)
-247 else:
-248 attn = attn.softmax(dim=-1)
+245 if self.is_inplace:
+246 half = attn.shape[0] // 2
+247 attn[half:] = attn[half:].softmax(dim=-1)
+248 attn[:half] = attn[:half].softmax(dim=-1)
+249 else:
+250 attn = attn.softmax(dim=-1)
-252 out = torch.einsum('bhij,bjhd->bihd', attn, v)
+254 out = torch.einsum('bhij,bjhd->bihd', attn, v)
-254 out = out.reshape(*out.shape[:2], -1)
+256 out = out.reshape(*out.shape[:2], -1)
-256 return self.to_out(out)
+258 return self.to_out(out)
-259 class FeedForward(nn.Module):
+261 class FeedForward(nn.Module):
-264 def __init__(self, d_model: int, d_mult: int = 4):
+266 def __init__(self, d_model: int, d_mult: int = 4):
-269 super().__init__()
-270 self.net = nn.Sequential(
-271 GeGLU(d_model, d_model * d_mult),
-272 nn.Dropout(0.),
-273 nn.Linear(d_model * d_mult, d_model)
-274 )
+271 super().__init__()
+272 self.net = nn.Sequential(
+273 GeGLU(d_model, d_model * d_mult),
+274 nn.Dropout(0.),
+275 nn.Linear(d_model * d_mult, d_model)
+276 )
-276 def forward(self, x: torch.Tensor):
-277 return self.net(x)
+278 def forward(self, x: torch.Tensor):
+279 return self.net(x)
-280 class GeGLU(nn.Module):
+282 class GeGLU(nn.Module):
-287 def __init__(self, d_in: int, d_out: int):
-288 super().__init__()
+289 def __init__(self, d_in: int, d_out: int):
+290 super().__init__()
-290 self.proj = nn.Linear(d_in, d_out * 2)
+292 self.proj = nn.Linear(d_in, d_out * 2)
-292 def forward(self, x: torch.Tensor):
+294 def forward(self, x: torch.Tensor):
-294 x, gate = self.proj(x).chunk(2, dim=-1)
+296 x, gate = self.proj(x).chunk(2, dim=-1)
-296 return x * F.gelu(gate)
+298 return x * F.gelu(gate)
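
Note on the renumbered normal_attention block above: it is the fallback path used when flash attention is unavailable. Heads are split out of the last dimension, scores are taken with an einsum, and when is_inplace is set the softmax runs on one half of the batch at a time (the conditional and unconditional halves of a classifier-free-guidance batch) to cut peak memory. Below is a minimal, self-contained sketch of that path; the final to_out projection is omitted and the demo shapes at the bottom are made up for illustration.

import torch

def normal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     n_heads: int, scale: float, is_inplace: bool = True):
    # q, k, v: [batch_size, seq_len, n_heads * d_head]
    q = q.view(*q.shape[:2], n_heads, -1)
    k = k.view(*k.shape[:2], n_heads, -1)
    v = v.view(*v.shape[:2], n_heads, -1)
    # attention scores: [batch_size, n_heads, seq_q, seq_k]
    attn = torch.einsum('bihd,bjhd->bhij', q, k) * scale
    if is_inplace:
        # softmax one half of the batch at a time to halve peak memory use
        half = attn.shape[0] // 2
        attn[half:] = attn[half:].softmax(dim=-1)
        attn[:half] = attn[:half].softmax(dim=-1)
    else:
        attn = attn.softmax(dim=-1)
    # weighted sum of values, then merge the heads back
    out = torch.einsum('bhij,bjhd->bihd', attn, v)
    return out.reshape(*out.shape[:2], -1)

# toy check (assumed shapes): batch of 2 (cond + uncond), 16 tokens, 8 heads of 64 dims
q = k = v = torch.randn(2, 16, 8 * 64)
print(normal_attention(q, k, v, n_heads=8, scale=64 ** -0.5).shape)  # torch.Size([2, 16, 512])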