diff --git a/docs/transformers/rope/value_pe/index.html b/docs/transformers/rope/value_pe/index.html
index f409d1ad..9b35ca38 100644
--- a/docs/transformers/rope/value_pe/index.html
+++ b/docs/transformers/rope/value_pe/index.html
@@ -412,7 +412,7 @@
-234        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
+234        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
diff --git a/labml_nn/transformers/rope/value_pe/__init__.py b/labml_nn/transformers/rope/value_pe/__init__.py
index a87c4462..8aadeab8 100644
--- a/labml_nn/transformers/rope/value_pe/__init__.py
+++ b/labml_nn/transformers/rope/value_pe/__init__.py
@@ -231,7 +231,7 @@ class RotaryValuePEMultiHeadAttention(RotaryPEMultiHeadAttention):
         # Multiply by values
         # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
-        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
+        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

         # Rotate in the opposite direction so that each embedding hold the relative positions
         x = self.value_reverse_rotary_pe(x)
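
Note (not part of the diff): a minimal standalone sketch of the shape convention the "ijbh,jbhd->ibhd" contraction assumes, i.e. the softmax-weighted sum over the key/value positions. The tensor names and sizes below are hypothetical and only illustrate the einsum used in the changed line.

    import torch

    # Hypothetical sizes: query length, key length, batch, heads, per-head dimension
    seq_q, seq_k, batch, heads, d_k = 7, 7, 2, 4, 16

    # attn: attention weights after softmax over the key dimension, shape [seq_q, seq_k, batch, heads]
    attn = torch.softmax(torch.randn(seq_q, seq_k, batch, heads), dim=1)
    # value: value vectors, shape [seq_k, batch, heads, d_k]
    value = torch.randn(seq_k, batch, heads, d_k)

    # out[i, b, h, d] = sum_j attn[i, j, b, h] * value[j, b, h, d]
    x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
    assert x.shape == (seq_q, batch, heads, d_k)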