diff --git a/labml_nn/transformers/rope/__init__.py b/labml_nn/transformers/rope/__init__.py
index 7729cbd1..05ab6875 100644
--- a/labml_nn/transformers/rope/__init__.py
+++ b/labml_nn/transformers/rope/__init__.py
@@ -168,6 +168,9 @@ class RotaryPositionalEmbeddings(nn.Module):
         # Cache $\cos$ and $\sin$ values
         self._build_cache(x)

+        # Sequence length
+        seq_len = x.shape[0]
+
         # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
         x_rope, x_pass = x[..., :self.d], x[..., self.d:]

@@ -185,7 +188,7 @@ class RotaryPositionalEmbeddings(nn.Module):
         # \end{align}
         #
         # for $i \in {1, 2, ..., \frac{d}{2}}$
-        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])

         #
         return torch.cat((x_rope, x_pass), dim=-1)
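The change hoists `x.shape[0]` into a named `seq_len` and reuses it when slicing the cached `cos`/`sin` tensors. For reviewers who want to sanity-check the refactor in isolation, here is a minimal, self-contained sketch of the rotation step. Note that `apply_rope` is a hypothetical helper written for illustration, not part of `labml_nn`; it assumes the module's `[seq_len, batch_size, n_heads, d]` input layout and `cos`/`sin` caches shaped `[seq_len, 1, 1, d]`, as built by `_build_cache`.

```python
import torch


def apply_rope(x: torch.Tensor, cos_cached: torch.Tensor,
               sin_cached: torch.Tensor, d: int) -> torch.Tensor:
    # Read the sequence length once, as the diff does, rather than
    # repeating x.shape[0] inside the rotation expression below.
    seq_len = x.shape[0]

    # Rotate only the first d features; pass the rest through unchanged.
    x_rope, x_pass = x[..., :d], x[..., d:]

    # Build [-x_{d/2+1..d}, x_{1..d/2}], the paired "negative half" RoPE uses.
    d_2 = d // 2
    neg_half_x = torch.cat([-x_rope[..., d_2:], x_rope[..., :d_2]], dim=-1)

    # Slice the caches down to the actual sequence length before broadcasting.
    x_rope = (x_rope * cos_cached[:seq_len]) + (neg_half_x * sin_cached[:seq_len])

    return torch.cat((x_rope, x_pass), dim=-1)


# Usage: build toy cos/sin caches and apply the rotation.
seq_len, batch, heads, d = 8, 2, 4, 16
theta = 1.0 / (10_000 ** (torch.arange(0, d, 2).float() / d))         # [d/2]
idx_theta = torch.arange(seq_len).float()[:, None] * theta[None, :]   # [seq_len, d/2]
idx_theta = torch.cat([idx_theta, idx_theta], dim=-1)                 # [seq_len, d]
cos_cached = idx_theta.cos()[:, None, None, :]                        # [seq_len, 1, 1, d]
sin_cached = idx_theta.sin()[:, None, None, :]

x = torch.randn(seq_len, batch, heads, d)
out = apply_rope(x, cos_cached, sin_cached, d)
assert out.shape == x.shape
```

Naming the slice length once keeps the already-long rotation line readable, and any future change to which axis carries the sequence dimension only needs to touch one place.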