Varuna Jayasiri
2022-09-24 14:35:28 +05:30
parent 37b30cfc3f
commit 160f25a938
2 changed files with 4 additions and 4 deletions


@@ -174,9 +174,9 @@ class CrossAttention(nn.Module):
         v = self.to_v(cond)
 
         if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
-            self.flash_attention(q, k, v)
+            return self.flash_attention(q, k, v)
         else:
-            self.normal_attention(q, k, v)
+            return self.normal_attention(q, k, v)
 
     def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
         """