mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-29 09:38:56 +08:00
fix
This commit is contained in:
@ -164,8 +164,10 @@ class CrossAttention(nn.Module):
|
||||
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
|
||||
"""
|
||||
|
||||
has_cond = cond is not None
|
||||
|
||||
# If `cond` is `None` we perform self attention
|
||||
if cond is None:
|
||||
if not has_cond:
|
||||
cond = x
|
||||
|
||||
# Get query, key and value vectors
|
||||
@ -173,9 +175,7 @@ class CrossAttention(nn.Module):
|
||||
k = self.to_k(cond)
|
||||
v = self.to_v(cond)
|
||||
|
||||
print('use flash', CrossAttention.use_flash_attention, self.flash)
|
||||
|
||||
if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
|
||||
if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
|
||||
return self.flash_attention(q, k, v)
|
||||
else:
|
||||
return self.normal_attention(q, k, v)
|
||||
|
||||
Reference in New Issue
Block a user