Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-08-14 17:41:37 +08:00.
seq length in rope
This commit is contained in:
@@ -168,6 +168,9 @@ class RotaryPositionalEmbeddings(nn.Module):
         # Cache $\cos$ and $\sin$ values
         self._build_cache(x)

+        # Sequence length
+        seq_len = x.shape[0]
+
         # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
         x_rope, x_pass = x[..., :self.d], x[..., self.d:]

@@ -185,7 +188,7 @@ class RotaryPositionalEmbeddings(nn.Module):
         # \end{align}
         #
         # for $i \in {1, 2, ..., \frac{d}{2}}$
-        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])

         #
         return torch.cat((x_rope, x_pass), dim=-1)
|
Reference in New Issue
Block a user