From 09d09379c2169eac06662e17cb9969dc6e48e36a Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Thu, 20 Jun 2024 12:53:09 +0530
Subject: [PATCH] fix value pe double rotation

---
 docs/transformers/rope/value_pe/index.html      | 2 +-
 labml_nn/transformers/rope/value_pe/__init__.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/transformers/rope/value_pe/index.html b/docs/transformers/rope/value_pe/index.html
index f409d1ad..9b35ca38 100644
--- a/docs/transformers/rope/value_pe/index.html
+++ b/docs/transformers/rope/value_pe/index.html
@@ -412,7 +412,7 @@
-234        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
+234        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

diff --git a/labml_nn/transformers/rope/value_pe/__init__.py b/labml_nn/transformers/rope/value_pe/__init__.py
index a87c4462..8aadeab8 100644
--- a/labml_nn/transformers/rope/value_pe/__init__.py
+++ b/labml_nn/transformers/rope/value_pe/__init__.py
@@ -231,7 +231,7 @@ class RotaryValuePEMultiHeadAttention(RotaryPEMultiHeadAttention):

         # Multiply by values
         # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
-        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
+        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

         # Rotate in the opposite direction so that each embedding hold the relative positions
         x = self.value_reverse_rotary_pe(x)
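
The change removes the second application of the value rotary positional encoding: presumably value is already rotated once earlier in RotaryValuePEMultiHeadAttention.forward, so rotating it again inside the einsum applied the forward rotation twice, which a single value_reverse_rotary_pe afterwards cannot undo. Below is a minimal, self-contained sketch of the intended single-rotation flow under that assumption. rope_angles, rotate_half, and rotary are hypothetical stand-ins for the library's value_rotary_pe / value_reverse_rotary_pe modules and only assume standard RoPE-style rotations; they are not the library's actual API.

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dimension with a sign flip; this is the
    # standard trick for applying a RoPE rotation to (x_i, x_{i + d/2}) pairs.
    d_2 = x.shape[-1] // 2
    return torch.cat([-x[..., d_2:], x[..., :d_2]], dim=-1)


def rope_angles(seq_len: int, d: int) -> torch.Tensor:
    # Standard RoPE angles, duplicated so both halves of a pair share an angle.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
    angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]  # [seq, d/2]
    return torch.cat([angles, angles], dim=-1)  # [seq, d]


def rotary(x: torch.Tensor, theta: torch.Tensor, reverse: bool = False) -> torch.Tensor:
    # Rotate by theta; reverse=True rotates by -theta and undoes the forward rotation.
    sign = -1.0 if reverse else 1.0
    return x * torch.cos(theta) + rotate_half(x) * torch.sin(sign * theta)


seq, batch, heads, d = 6, 2, 4, 16
theta = rope_angles(seq, d)[:, None, None, :]  # broadcast over batch and heads

attn = torch.softmax(torch.randn(seq, seq, batch, heads), dim=1)  # [i, j, batch, heads]
value = torch.randn(seq, batch, heads, d)

# Rotate the values exactly once (the role played by value_rotary_pe).
rotated_value = rotary(value, theta)

# Attention-weighted sum over the already-rotated values; this mirrors the fixed line
#   x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# where value has been rotated once beforehand.
x = torch.einsum("ijbh,jbhd->ibhd", attn, rotated_value)

# Reverse rotation (the role played by value_reverse_rotary_pe), so each output
# embedding carries relative rather than absolute positional information.
x = rotary(x, theta, reverse=True)

# Sanity check: one forward rotation followed by the reverse rotation is the identity,
# which is why rotating twice on the way in cannot be undone by a single reverse pass.
assert torch.allclose(rotary(rotary(value, theta), theta, reverse=True), value, atol=1e-5)
print(x.shape)  # torch.Size([6, 2, 4, 16])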