seq length in rope
@@ -168,6 +168,9 @@ class RotaryPositionalEmbeddings(nn.Module):
         # Cache $\cos$ and $\sin$ values
         self._build_cache(x)
 
+        # Sequence length
+        seq_len = x.shape[0]
+
         # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
         x_rope, x_pass = x[..., :self.d], x[..., self.d:]
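For context, the cached tables sliced later in this diff are built by `_build_cache`. Below is a minimal, self-contained sketch of how such a cache could be constructed. The class name, `self.d`, and the `cos_cached`/`sin_cached` attributes come from the diff itself; the constructor signature, the 10,000 frequency base, and the exact tensor shapes are assumptions for illustration, not the repository's verbatim code.

import torch
import torch.nn as nn


class RotaryPositionalEmbeddings(nn.Module):
    def __init__(self, d: int, base: int = 10_000):
        super().__init__()
        self.d = d          # number of features to rotate
        self.base = base    # assumed RoPE frequency base
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        # Rebuild only if the cache is missing or shorter than the current sequence.
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return
        seq_len = x.shape[0]
        # theta_i = base^(-2 i / d) for the d/2 feature pairs
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2, device=x.device).float() / self.d))
        # Position indices m = 0 .. seq_len - 1
        seq_idx = torch.arange(seq_len, device=x.device).float()
        # Outer product m * theta_i, shape [seq_len, d/2]
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
        # Repeat so the table covers all d rotary features, shape [seq_len, d]
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
        # Keep broadcast dimensions for (seq, batch, heads, features)
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]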
@@ -185,7 +188,7 @@ class RotaryPositionalEmbeddings(nn.Module):
         # \end{align}
         #
         # for $i \in {1, 2, ..., \frac{d}{2}}$
-        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])
 
         #
         return torch.cat((x_rope, x_pass), dim=-1)
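Continuing the sketch above, a `forward` that matches the diff ties the pieces together: build the cache, record `seq_len`, split the features into a rotary part and a pass-through part, form the rotated half (`neg_half_x` in the diff), and slice the cached tables by `seq_len`, which is exactly the line this commit changes. The method body and the usage example below are a hedged illustration under the assumed [seq_len, batch, heads, d_head] input shape, not the repository's verbatim code.

    # Continues the RotaryPositionalEmbeddings sketch shown after the first hunk.
    def forward(self, x: torch.Tensor):
        # Assumed input shape: [seq_len, batch, heads, d_head]
        # Cache cos and sin values (rebuilt only when the sequence grows)
        self._build_cache(x)

        # Sequence length
        seq_len = x.shape[0]

        # Apply rotary embeddings only to the first self.d features
        x_rope, x_pass = x[..., :self.d], x[..., self.d:]

        # Rotated counterpart [-x_{d/2+1..d}, x_{1..d/2}] used by RoPE
        d_2 = self.d // 2
        neg_half_x = torch.cat([-x_rope[..., d_2:], x_rope[..., :d_2]], dim=-1)

        # Slice the caches by the current sequence length (the change in this commit)
        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])

        return torch.cat((x_rope, x_pass), dim=-1)


# Hypothetical usage: rotate all 64 features of a [seq_len, batch, heads, d_head] tensor
rope = RotaryPositionalEmbeddings(d=64)
x = torch.randn(10, 2, 8, 64)
out = rope(x)
assert out.shape == x.shape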