This commit is contained in:
Varuna Jayasiri
2022-09-24 14:41:55 +05:30
parent f4b2d46925
commit 1a49b753e4
2 changed files with 81 additions and 72 deletions

View File

@@ -164,8 +164,10 @@ class CrossAttention(nn.Module):
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
"""
has_cond = cond is not None
# If `cond` is `None` we perform self attention
if cond is None:
if not has_cond:
cond = x
# Get query, key and value vectors
@@ -173,9 +175,7 @@ class CrossAttention(nn.Module):
k = self.to_k(cond)
v = self.to_v(cond)
print('use flash', CrossAttention.use_flash_attention, self.flash)
if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
return self.flash_attention(q, k, v)
else:
return self.normal_attention(q, k, v)