This commit is contained in:
Varuna Jayasiri
2022-09-24 14:39:10 +05:30
parent de36f9b6be
commit eb92824e58
2 changed files with 59 additions and 55 deletions

View File

@@ -173,6 +173,8 @@ class CrossAttention(nn.Module):
k = self.to_k(cond)
v = self.to_v(cond)
print('use flash', CrossAttention.use_flash_attention)
if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
return self.flash_attention(q, k, v)
else:
@@ -186,7 +188,7 @@ class CrossAttention(nn.Module):
"""
print('flash')
# Get batch size and number of elements along sequence axis (width * height)
batch_size, seq_len, _ = q.shape