seq length in rope

Author: Varuna Jayasiri
Date:   2025-07-18 10:43:38 +05:30
parent f6d77c36b2
commit 47d4231a73


@@ -168,6 +168,9 @@ class RotaryPositionalEmbeddings(nn.Module):
         # Cache $\cos$ and $\sin$ values
         self._build_cache(x)
 
+        # Sequence length
+        seq_len = x.shape[0]
+
         # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
         x_rope, x_pass = x[..., :self.d], x[..., self.d:]
 
@@ -185,7 +188,7 @@ class RotaryPositionalEmbeddings(nn.Module):
         # \end{align}
         #
         # for $i \in {1, 2, ..., \frac{d}{2}}$
-        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])
         #
         return torch.cat((x_rope, x_pass), dim=-1)
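For reference, the pattern the commit settles on is: build (or reuse) a cos/sin cache covering at least the current sequence length, then slice that cache with a single seq_len variable when rotating the features. Below is a minimal, self-contained sketch of that pattern, not the repository's full class; the class name RotaryPositionalEmbeddingsSketch, the input layout [seq_len, batch_size, n_heads, d_head], the base=10_000 default, and the cache-rebuild check are assumptions chosen for illustration.

import torch
import torch.nn as nn


class RotaryPositionalEmbeddingsSketch(nn.Module):
    """Minimal RoPE sketch (hypothetical): rotate the first `d` features of `x`."""

    def __init__(self, d: int, base: int = 10_000):
        super().__init__()
        self.d = d
        self.base = base
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        # Rebuild only if the cache is missing or shorter than this sequence
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return
        seq_len = x.shape[0]
        # theta_i = 1 / base^(2i/d) for i in [0, d/2)
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2, device=x.device).float() / self.d))
        seq_idx = torch.arange(seq_len, device=x.device).float()
        # Position m times theta_i, duplicated so it covers all d rotated features
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
        # Shapes [seq_len, 1, 1, d] broadcast over batch and heads
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def forward(self, x: torch.Tensor):
        # x: [seq_len, batch_size, n_heads, d_head]  (assumed layout)
        self._build_cache(x)

        # Sequence length -- the value the commit stores in a local variable
        seq_len = x.shape[0]

        # Apply rotary embeddings only to the first self.d features
        x_rope, x_pass = x[..., :self.d], x[..., self.d:]

        # Build [-x_{d/2+1..d}, x_{1..d/2}], the "negative half" used by the rotation
        d_2 = self.d // 2
        neg_half_x = torch.cat([-x_rope[..., d_2:], x_rope[..., :d_2]], dim=-1)

        # Slice the cache to the current sequence length before applying it
        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])

        return torch.cat((x_rope, x_pass), dim=-1)


if __name__ == '__main__':
    rope = RotaryPositionalEmbeddingsSketch(d=32)
    x = torch.randn(10, 2, 4, 64)  # [seq_len, batch, heads, d_head]
    assert rope(x).shape == x.shape

Slicing with a named seq_len rather than repeating x.shape[0] keeps the two cache slices in sync and reads closer to the "# Sequence length" comment the commit adds.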