From 47d4231a730dff11594e37a256a57774018dc051 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Fri, 18 Jul 2025 10:43:38 +0530
Subject: [PATCH] seq length in rope

---
 labml_nn/transformers/rope/__init__.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/labml_nn/transformers/rope/__init__.py b/labml_nn/transformers/rope/__init__.py
index 7729cbd1..05ab6875 100644
--- a/labml_nn/transformers/rope/__init__.py
+++ b/labml_nn/transformers/rope/__init__.py
@@ -168,6 +168,9 @@ class RotaryPositionalEmbeddings(nn.Module):
         # Cache $\cos$ and $\sin$ values
         self._build_cache(x)
 
+        # Sequence length
+        seq_len = x.shape[0]
+
         # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
         x_rope, x_pass = x[..., :self.d], x[..., self.d:]
 
@@ -185,7 +188,7 @@ class RotaryPositionalEmbeddings(nn.Module):
         # \end{align}
         #
         # for $i \in {1, 2, ..., \frac{d}{2}}$
-        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])
 
         #
         return torch.cat((x_rope, x_pass), dim=-1)
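
For context, the patch names the sequence length once (seq_len = x.shape[0]) and reuses it when slicing the cached cos/sin tensors, instead of repeating x.shape[0]. Below is a minimal standalone sketch of the same computation; the function name apply_rope is hypothetical (not the library's API), and it assumes x has shape [seq_len, batch_size, n_heads, d_head] and the caches have shape [max_seq_len, 1, 1, d], as in the surrounding labml_nn module.

    import torch

    def apply_rope(x: torch.Tensor, cos_cached: torch.Tensor, sin_cached: torch.Tensor, d: int) -> torch.Tensor:
        # x: [seq_len, batch_size, n_heads, d_head]; caches: [max_seq_len, 1, 1, d]
        seq_len = x.shape[0]
        # Only the first `d` features get the rotary embedding; the rest pass through
        x_rope, x_pass = x[..., :d], x[..., d:]
        # Build [-x_{d/2..d}, x_{0..d/2}] so one multiply-add performs the 2D rotation per pair
        neg_half_x = torch.cat([-x_rope[..., d // 2:], x_rope[..., :d // 2]], dim=-1)
        # Slice the caches to the current sequence length (the slicing this patch makes explicit)
        x_rope = (x_rope * cos_cached[:seq_len]) + (neg_half_x * sin_cached[:seq_len])
        return torch.cat((x_rope, x_pass), dim=-1)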