seq length in rope

This commit is contained in:
Varuna Jayasiri
2025-07-18 10:43:38 +05:30
parent f6d77c36b2
commit 47d4231a73

View File

@ -168,6 +168,9 @@ class RotaryPositionalEmbeddings(nn.Module):
# Cache $\cos$ and $\sin$ values
self._build_cache(x)
# Sequence length
seq_len = x.shape[0]
# Split the features, we can choose to apply rotary embeddings only to a partial set of features.
x_rope, x_pass = x[..., :self.d], x[..., self.d:]
@ -185,7 +188,7 @@ class RotaryPositionalEmbeddings(nn.Module):
# \end{align}
#
# for $i \in {1, 2, ..., \frac{d}{2}}$
x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])
#
return torch.cat((x_rope, x_pass), dim=-1)