diff --git a/docs/diffusion/stable_diffusion/model/unet_attention.html b/docs/diffusion/stable_diffusion/model/unet_attention.html
index 0a717616..e9041d74 100644
--- a/docs/diffusion/stable_diffusion/model/unet_attention.html
+++ b/docs/diffusion/stable_diffusion/model/unet_attention.html
@@ -603,9 +603,9 @@
174 v = self.to_v(cond)
175
176 if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
-177 self.flash_attention(q, k, v)
+177 return self.flash_attention(q, k, v)
178 else:
-179 self.normal_attention(q, k, v)
+179 return self.normal_attention(q, k, v)
diff --git a/labml_nn/diffusion/stable_diffusion/model/unet_attention.py b/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
index 4042d500..77e018e6 100644
--- a/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
+++ b/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
@@ -174,9 +174,9 @@ class CrossAttention(nn.Module):
v = self.to_v(cond)
if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
- self.flash_attention(q, k, v)
+ return self.flash_attention(q, k, v)
else:
- self.normal_attention(q, k, v)
+ return self.normal_attention(q, k, v)
def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""