This is an implementation of Rotary Positional Embeddings (RoPE) in PyTorch.
Rotary Positional Embeddings (RoPE) encode position information of tokens with a rotation matrix that naturally incorporates explicit relative position dependency.
Here's the training code for training a transformer model with RoPE on the Tiny Shakespeare dataset.
import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
Rotary encoding transforms pairs of features by rotating in the 2D plane. That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it by an angle depending on the position of the token.
Let $x^{(1)}_m$ and $x^{(2)}_m$ be two features of the key or query of any head at position $m$. Or for simplicity assume $x$ has only two features. Then the transformation is,

\begin{align}
RoPE\big(x^{(1)}_m, x^{(2)}_m, m\big) &=
\begin{pmatrix}
\cos m \theta & - \sin m \theta \\
\sin m \theta & \cos m \theta
\end{pmatrix}
\begin{pmatrix}
x^{(1)}_m \\
x^{(2)}_m
\end{pmatrix} \\
&=
\begin{pmatrix}
x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m \theta \\
x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m \theta
\end{pmatrix}
\end{align}

where $\theta$ is a constant angle. The other pairs of features are transformed similarly.
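As a quick numerical illustration, here is a minimal sketch (the helper rotate_pair and the sample values are ours, not part of the module below) of rotating one feature pair by $m\theta$:

import math

def rotate_pair(x1: float, x2: float, m: int, theta: float):
    # Rotate the pair (x1, x2) by the angle m * theta
    cos_m, sin_m = math.cos(m * theta), math.sin(m * theta)
    return x1 * cos_m - x2 * sin_m, x2 * cos_m + x1 * sin_m

# The same pair is rotated by a different angle at each position m
rotate_pair(1.0, 2.0, m=0, theta=0.5)  # (1.0, 2.0): position 0 is unchanged
rotate_pair(1.0, 2.0, m=3, theta=0.5)  # rotated by 1.5 radians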
For a pair of features, the dot-product attention score between two positions $m$ and $n$ would be

\begin{align}
\Big\langle RoPE\big(x^{(1)}_m, x^{(2)}_m, m\big), RoPE\big(x^{(1)}_n, x^{(2)}_n, n\big) \Big\rangle
&= \big(x^{(1)}_m x^{(1)}_n + x^{(2)}_m x^{(2)}_n\big) \cos (m - n)\theta
 + \big(x^{(1)}_m x^{(2)}_n - x^{(2)}_m x^{(1)}_n\big) \sin (m - n)\theta \\
&= \Big\langle RoPE\big(x^{(1)}_m, x^{(2)}_m, m - n\big), RoPE\big(x^{(1)}_n, x^{(2)}_n, 0\big) \Big\rangle
\end{align}

This shows that for dot-product attention the rotary encodings give relative attention: the score depends on the positions only through their difference $m - n$.
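A minimal sanity-check sketch of this property (the helper rope_2d and the toy numbers are ours, for illustration only): rotating a query pair by $m\theta$ and a key pair by $n\theta$ gives the same dot product as rotating them by $(m - n)\theta$ and $0$.

import math

def rope_2d(x1: float, x2: float, pos: int, theta: float = 0.5):
    # Rotate the feature pair (x1, x2) by pos * theta
    c, s = math.cos(pos * theta), math.sin(pos * theta)
    return x1 * c - x2 * s, x2 * c + x1 * s

q, k = (1.0, 2.0), (3.0, -1.0)
m, n = 7, 4

q_m, k_n = rope_2d(*q, m), rope_2d(*k, n)
q_rel, k_rel = rope_2d(*q, m - n), rope_2d(*k, 0)

# Both attention scores match: the score depends only on m - n
print(q_m[0] * k_n[0] + q_m[1] * k_n[1])
print(q_rel[0] * k_rel[0] + q_rel[1] * k_rel[1])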
The features are grouped into pairs and handled as above. They use a different $\theta_i$ for each pair.

The paper suggests using $\Theta = \big\{\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in \big[1, 2, ..., \frac{d}{2}\big]\big\}$ for the $\frac{d}{2}$ pairs of features.

We pair feature $i$ with feature $i + \frac{d}{2}$. So for position $m$ we transform

\begin{align}
\begin{pmatrix}
x^{(i)}_m \\
x^{(i + \frac{d}{2})}_m
\end{pmatrix}
\end{align}

to

\begin{align}
\begin{pmatrix}
x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i \\
x^{(i + \frac{d}{2})}_m \cos m \theta_i + x^{(i)}_m \sin m \theta_i
\end{pmatrix}
\end{align}
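To make the pairing concrete, here is a minimal tensor-level sketch of that half-split rotation (the function rope_half_split and the toy shapes are ours; it assumes a [seq_len, d] input rather than the [seq_len, batch_size, n_heads, d] shape handled by the module below):

import torch

def rope_half_split(x: torch.Tensor, base: float = 10_000.):
    # x: [seq_len, d]; feature i is paired with feature i + d/2
    seq_len, d = x.shape
    theta = base ** (-torch.arange(0, d, 2).float() / d)  # [d/2]
    m = torch.arange(seq_len).float()                     # [seq_len]
    angles = torch.einsum('n,d->nd', m, theta)            # [seq_len, d/2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, :d // 2], x[:, d // 2:]
    # (x_i, x_{i+d/2}) -> (x_i cos - x_{i+d/2} sin, x_{i+d/2} cos + x_i sin)
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

print(rope_half_split(torch.randn(3, 4)).shape)  # torch.Size([3, 4])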
class RotaryPositionalEmbeddings(nn.Module):
d is the number of features
base is the constant used for calculating $\Theta$
    def __init__(self, d: int, base: int = 10_000):
        super().__init__()

        self.base = base
        self.d = d
        self.cos_cached = None
        self.sin_cached = None
Cache $\cos$ and $\sin$ values
    def _build_cache(self, x: torch.Tensor):
Return if cache is already built
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return
Get sequence length
        seq_len = x.shape[0]
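$\Theta = \big\{\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in \big[1, 2, ..., \frac{d}{2}\big]\big\}$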
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
Create position indexes [0, 1, ..., seq_len - 1]
        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
Concatenate so that for row $m$ we have $[m \theta_1, m \theta_2, ..., m \theta_{\frac{d}{2}}, m \theta_1, m \theta_2, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
Cache them
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]
    def _neg_half(self, x: torch.Tensor):
        d_2 = self.d // 2
Calculate $\big[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}\big]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
x is the Tensor at the head of a key or a query with shape [seq_len, batch_size, n_heads, d]
    def forward(self, x: torch.Tensor):
Cache $\cos$ and $\sin$ values
        self._build_cache(x)
Sequence length
        seq_len = x.shape[0]
Split the features; we can choose to apply rotary embeddings only to a partial set of features.
        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
Calculate $\big[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}\big]$
        neg_half_x = self._neg_half(x_rope)
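Calculate $x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i$ for the first half of the features and $x^{(i + \frac{d}{2})}_m \cos m \theta_i + x^{(i)}_m \sin m \theta_i$ for the second half, by combining $x$ with its negated-and-swapped halves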
        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])
        return torch.cat((x_rope, x_pass), dim=-1)
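For instance, a minimal usage sketch of the module above (the toy shapes are ours): with 8 features per head and d = 4, only the first 4 features are rotated and the rest pass through unchanged.

q = torch.randn(10, 2, 4, 8)             # [seq_len, batch_size, n_heads, d_k]
rotary = RotaryPositionalEmbeddings(4)   # rotate only the first 4 features
out = rotary(q)
assert out.shape == q.shape
assert torch.equal(out[..., 4:], q[..., 4:])  # pass-through features are unchanged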
We override the multi-head attention module from the original transformer.
class RotaryPEMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
        super().__init__(heads, d_model, dropout_prob)
Rotary positional embedding layers
        d_rope = int(self.d_k * rope_percentage)
        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
Calculate dot-product with RoPE
        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
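As a rough shape sketch (the toy dimensions are ours), this einsum contracts the feature dimension and yields one score per query position, key position, batch and head:

q = torch.randn(10, 2, 4, 16)   # [seq_len_q, batch_size, n_heads, d_k]
k = torch.randn(12, 2, 4, 16)   # [seq_len_k, batch_size, n_heads, d_k]
scores = torch.einsum('ibhd,jbhd->ijbh', q, k)
print(scores.shape)             # torch.Size([10, 12, 2, 4])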
Testing RoPE with a simple example
def _test_rotary():
    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
    x = x[:, None, None, :]
    inspect(x)

    rotary_pe = RotaryPositionalEmbeddings(4)
    inspect(rotary_pe(x))


if __name__ == '__main__':
    _test_rotary()