This commit is contained in:
Varuna Jayasiri
2022-09-24 14:40:06 +05:30
parent eb92824e58
commit f4b2d46925
2 changed files with 2 additions and 2 deletions

View File

@ -173,7 +173,7 @@ class CrossAttention(nn.Module):
k = self.to_k(cond)
v = self.to_v(cond)
print('use flash', CrossAttention.use_flash_attention)
print('use flash', CrossAttention.use_flash_attention, self.flash)
if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
return self.flash_attention(q, k, v)