This is an implementation of Rotary Positional Embeddings (RoPE) in PyTorch.
Rotary Positional Embeddings (RoPE) encode position information of tokens with a rotation matrix that naturally incorporates explicit relative position dependency.
Here's the training code for training a transformer model with RoPE on the Tiny Shakespeare dataset.
```python
import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
```

Rotary encoding transforms pairs of features by rotating them in the 2D plane. That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. Each pair can be considered a coordinate in a 2D plane, and the encoding rotates it by an angle that depends on the position of the token.
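One way to picture "rotating a pair" is to read the pair $(x^{(1)}, x^{(2)})$ as the complex number $x^{(1)} + i\,x^{(2)}$. Multiplying it by $e^{i\varphi}$ rotates it by $\varphi$ without changing its length. The sketch below is a minimal, plain-Python illustration of that view, not part of the PyTorch module. The helper name `rotate_pair` and the angle values are hypothetical choices made only for this example.

```python
import cmath
import math
from typing import Tuple


def rotate_pair(x1: float, x2: float, angle: float) -> Tuple[float, float]:
    """Rotate the 2D point (x1, x2) counter-clockwise by `angle` radians.

    The pair is read as the complex number x1 + i*x2. Multiplying by
    e^{i*angle}, which has modulus 1, rotates it without changing its length.
    """
    z = complex(x1, x2) * cmath.exp(1j * angle)
    return z.real, z.imag


# Illustrative values: a pair at position m = 3 with an assumed step angle of 0.5 rad.
m, step = 3, 0.5
y1, y2 = rotate_pair(1.0, 2.0, m * step)
print(math.hypot(1.0, 2.0), math.hypot(y1, y2))  # same length before and after rotation
```

Because the rotation preserves length, the encoding changes only the direction of each pair. The position is carried entirely by the angle.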
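Let $x^{(1)}_m$ and $x^{(2)}_m$ be two features of the key or query of any head at position $m$. Or, for simplicity, assume $x$ has only two features. Then the transformation is

\begin{align}
RoPE\big(x^{(1)}_m, x^{(2)}_m, m\big) &=
\begin{pmatrix}
\cos m\theta & -\sin m\theta \\
\sin m\theta & \cos m\theta
\end{pmatrix}
\begin{pmatrix}
x^{(1)}_m \\
x^{(2)}_m
\end{pmatrix} \\
&=
\begin{pmatrix}
x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m\theta \\
x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m\theta
\end{pmatrix}
\end{align}

where $\theta$ is a constant angle. The other pairs of features are transformed similarly.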
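The matrix form above translates directly into a few lines of code. The following is a minimal plain-Python sketch of the two-feature case. The function name `rope_pair` is a hypothetical helper, and $\theta$ is passed in explicitly because this section treats it as a given constant.

```python
import math
from typing import Tuple


def rope_pair(x1: float, x2: float, m: int, theta: float) -> Tuple[float, float]:
    """Apply the rotation [[cos m*theta, -sin m*theta], [sin m*theta, cos m*theta]]
    to the feature pair (x1, x2) at position m."""
    c, s = math.cos(m * theta), math.sin(m * theta)
    return x1 * c - x2 * s, x2 * c + x1 * s


# Position 0 is the identity; each later position rotates by one more step of theta.
print(rope_pair(1.0, 2.0, 0, 0.1))  # (1.0, 2.0)
print(rope_pair(1.0, 2.0, 2, 0.1))  # rotated by 0.2 rad
```

Since the angle is linear in $m$, rotating to position $m$ and then by one more step is the same as rotating straight to position $m + 1$.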
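For a pair of features, the dot-product attention score between two positions $m$ and $n$ would be

\begin{align}
&\Big\langle RoPE\big(x^{(1)}_m, x^{(2)}_m, m\big), RoPE\big(x^{(1)}_n, x^{(2)}_n, n\big)\Big\rangle \\
&= (x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m\theta)(x^{(1)}_n \cos n\theta - x^{(2)}_n \sin n\theta) \\
&\quad + (x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m\theta)(x^{(2)}_n \cos n\theta + x^{(1)}_n \sin n\theta) \\
&= x^{(1)}_m x^{(1)}_n (\cos m\theta \cos n\theta + \sin m\theta \sin n\theta) \\
&\quad + x^{(1)}_m x^{(2)}_n (-\cos m\theta \sin n\theta + \sin m\theta \cos n\theta) \\
&\quad + x^{(2)}_m x^{(1)}_n (-\sin m\theta \cos n\theta + \cos m\theta \sin n\theta) \\
&\quad + x^{(2)}_m x^{(2)}_n (\sin m\theta \sin n\theta + \cos m\theta \cos n\theta) \\
&= x^{(1)}_m x^{(1)}_n \cos (m - n)\theta + x^{(1)}_m x^{(2)}_n \sin (m - n)\theta \\
&\quad - x^{(2)}_m x^{(1)}_n \sin (m - n)\theta + x^{(2)}_m x^{(2)}_n \cos (m - n)\theta \\
&= \big(x^{(1)}_m \cos (m - n)\theta - x^{(2)}_m \sin (m - n)\theta\big) x^{(1)}_n \\
&\quad + \big(x^{(2)}_m \cos (m - n)\theta + x^{(1)}_m \sin (m - n)\theta\big) x^{(2)}_n \\
&= \Big\langle RoPE\big(x^{(1)}_m, x^{(2)}_m, m - n\big), RoPE\big(x^{(1)}_n, x^{(2)}_n, 0\big)\Big\rangle
\end{align}

This shows that, for dot-product attention, rotary encodings give relative attention: the score depends on the positions $m$ and $n$ only through their offset $m - n$.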
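A quick numerical spot check of the derivation can make it more concrete, although it is not a proof. This plain-Python sketch rotates an arbitrary query pair to position $m$ and an arbitrary key pair to position $n$, then takes their dot product. It compares several $(m, n)$ with the same offset. The names `rope_pair` and `score`, along with all numeric values, are hypothetical and chosen only for illustration.

```python
import math


def rope_pair(x1, x2, m, theta):
    """Rotate the feature pair (x1, x2) by angle m * theta."""
    c, s = math.cos(m * theta), math.sin(m * theta)
    return x1 * c - x2 * s, x2 * c + x1 * s


def score(q, k, m, n, theta):
    """Dot product of query pair q encoded at position m and key pair k encoded at position n."""
    qm = rope_pair(q[0], q[1], m, theta)
    kn = rope_pair(k[0], k[1], n, theta)
    return qm[0] * kn[0] + qm[1] * kn[1]


q, k, theta = (0.5, -1.2), (2.0, 0.7), 0.3
# Same offset m - n = 3 at very different absolute positions:
for m, n in [(3, 0), (10, 7), (250, 247)]:
    print(m, n, score(q, k, m, n, theta))  # equal up to floating-point rounding
```

The last line of the derivation also holds numerically. Rotating only the query by $m - n$ and leaving the key unrotated gives the same score.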
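The features are grouped into pairs and handled as above, using a different $\theta$ for each pair.

The paper suggests using $\Theta = \{\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, \ldots, \frac{d}{2}]\}$ for the $\frac{d}{2}$ pairs of features.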
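The frequency schedule can be computed in the same way as the module's cache. The module uses `1 / base ** (k / d)` for `k = 0, 2, ..., d - 2`, which equals $10000^{-\frac{2(i-1)}{d}}$ with $k = 2(i - 1)$. The plain-Python sketch below uses a hypothetical helper name, `rope_thetas`. It also prints the wavelength $2\pi / \theta_i$, meaning the number of positions a pair needs to complete one full turn.

```python
import math


def rope_thetas(d: int, base: float = 10_000.0) -> list:
    """theta_i = base^(-2(i-1)/d) for i = 1..d/2, written as 1 / base^(k/d) with k = 2(i-1)."""
    if d % 2:
        raise ValueError("d must be even so the features split into pairs")
    return [1.0 / (base ** (k / d)) for k in range(0, d, 2)]


print(rope_thetas(8))  # approximately [1.0, 0.1, 0.01, 0.001]
print([2 * math.pi / t for t in rope_thetas(8)])  # roughly 6.28, 62.8, 628, 6283 positions per turn
```

The first pair rotates fastest and the last pair slowest, so different pairs resolve position differences at different scales.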
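We pair feature $i$ with feature $i + \frac{d}{2}$. So for position $m$ we transform

\begin{align}
\begin{pmatrix}
x^{(i)}_m \\
x^{(i + \frac{d}{2})}_m
\end{pmatrix}
\end{align}

to

\begin{align}
\begin{pmatrix}
x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\
x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i
\end{pmatrix}
\end{align}

The paper's block-diagonal rotation matrix pairs adjacent features instead. Pairing $i$ with $i + \frac{d}{2}$ applies the same rotations to a fixed permutation of the features. Because queries and keys are permuted identically, the relative-position property above still holds.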
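The pairing of feature $i$ with feature $i + \frac{d}{2}$ can be vectorized the same way the module's `forward` does it. The sketch computes `x * cos + neg_half(x) * sin`, where the $\frac{d}{2}$ angles are repeated twice. This plain-Python version is for illustration only. The name `rope_vector` is hypothetical, and Python lists stand in for tensors.

```python
import math
from typing import List


def rope_vector(x: List[float], m: int, base: float = 10_000.0) -> List[float]:
    """Rotate feature i together with feature i + d/2 by angle m * theta_i.

    Mirrors the module's formula: x * cos + neg_half(x) * sin, where the
    cos/sin angles are [m*theta_1, ..., m*theta_{d/2}] repeated twice.
    """
    d = len(x)
    if d % 2:
        raise ValueError("d must be even")
    half = d // 2
    thetas = [1.0 / (base ** (k / d)) for k in range(0, d, 2)]
    angles = [m * t for t in thetas] * 2          # list repetition: d angles in total
    neg_half = [-v for v in x[half:]] + x[:half]   # [-x^(d/2+1), ..., -x^(d), x^(1), ..., x^(d/2)]
    return [xi * math.cos(a) + ni * math.sin(a) for xi, ni, a in zip(x, neg_half, angles)]


x = [1.0, 2.0, 3.0, 4.0]
print(rope_vector(x, 0))  # [1.0, 2.0, 3.0, 4.0]: position 0 is unchanged
print(rope_vector(x, 1))  # feature 1 paired with 3, feature 2 paired with 4
```

For $i \le \frac{d}{2}$ the sum gives $x^{(i)} \cos m\theta_i - x^{(i + \frac{d}{2})} \sin m\theta_i$, and for the second half it gives $x^{(i + \frac{d}{2})} \cos m\theta_i + x^{(i)} \sin m\theta_i$. These match the transformed vector above.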
```python
class RotaryPositionalEmbeddings(nn.Module):
    def __init__(self, d: int, base: int = 10_000):
        """
        * `d` is the number of features d
        * `base` is the constant used for calculating Theta
        """
        super().__init__()

        self.base = base
        self.d = d
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        """Cache cos and sin values"""
        # Return if the cache is already built and long enough
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return

        # Get sequence length
        seq_len = x.shape[0]

        # Theta = {theta_i = base^(-2(i-1)/d), i in [1, 2, ..., d/2]}
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

        # Create position indexes [0, 1, ..., seq_len - 1]
        seq_idx = torch.arange(seq_len, device=x.device).float()

        # Calculate the product of position index and theta_i
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)

        # Concatenate so that for row m we have
        # [m theta_1, ..., m theta_{d/2}, m theta_1, ..., m theta_{d/2}]
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Cache them with shape [seq_len, 1, 1, d] to broadcast over batch and heads
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # d/2
        d_2 = self.d // 2
        # Calculate [-x^(d/2 + 1), ..., -x^(d), x^(1), ..., x^(d/2)]
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is a key or query tensor of shape `[seq_len, batch_size, n_heads, features]`;
          only the first `d` features are rotated
        """
        # Cache cos and sin values
        self._build_cache(x)

        # Split the features; we can choose to apply rotary embeddings to only part of them
        x_rope, x_pass = x[..., :self.d], x[..., self.d:]

        # Calculate [-x^(d/2 + 1), ..., -x^(d), x^(1), ..., x^(d/2)]
        neg_half_x = self._neg_half(x_rope)

        # Calculate, for i in 1..d/2:
        #   x^(i)_m       cos(m theta_i) - x^(i + d/2)_m sin(m theta_i)
        #   x^(i + d/2)_m cos(m theta_i) + x^(i)_m       sin(m theta_i)
        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])

        # Re-attach the features that were not rotated
        return torch.cat((x_rope, x_pass), dim=-1)
```

We override multi-head attention from the original transformer.
```python
class RotaryPEMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
        super().__init__(heads, d_model, dropout_prob)

        # Rotary positional embedding layers, applied to the first
        # `rope_percentage` fraction of the d_k features of each head
        # (d_rope should be even so the rotated features split into pairs)
        d_rope = int(self.d_k * rope_percentage)
        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        # Calculate dot-product with RoPE
        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
```

Testing RoPE with a simple example:
```python
def _test_rotary():
    # Three positions with 4 features each, reshaped to
    # [seq_len, batch_size, n_heads, d] = [3, 1, 1, 4]
    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
    x = x[:, None, None, :]
    inspect(x)

    # Position 0 has angle 0, so the first row comes back unchanged
    rotary_pe = RotaryPositionalEmbeddings(4)
    inspect(rotary_pe(x))


if __name__ == '__main__':
    _test_rotary()
```
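To see what `RotaryPEMultiHeadAttention.get_scores` computes for a single batch element and head, the sketch below builds the raw score matrix $s_{ij} = \langle RoPE(q_i, i), RoPE(k_j, j) \rangle$ in plain Python. Only the first `rope_percentage` fraction of features is rotated. It omits any scaling, masking, and softmax that the surrounding attention code may apply. The names `rope_partial` and `rope_scores` are hypothetical, as are the example vectors.

```python
import math
from typing import List


def rope_partial(x: List[float], m: int, d_rope: int, base: float = 10_000.0) -> List[float]:
    """Rotate the first d_rope features at position m (pairing i with i + d_rope/2); pass the rest through."""
    if d_rope % 2:
        raise ValueError("d_rope must be even")
    x_rope, x_pass = x[:d_rope], x[d_rope:]
    half = d_rope // 2
    thetas = [1.0 / (base ** (k / d_rope)) for k in range(0, d_rope, 2)]
    angles = [m * t for t in thetas] * 2
    neg_half = [-v for v in x_rope[half:]] + x_rope[:half]
    rotated = [xi * math.cos(a) + ni * math.sin(a) for xi, ni, a in zip(x_rope, neg_half, angles)]
    return rotated + x_pass


def rope_scores(queries: List[List[float]], keys: List[List[float]], rope_percentage: float = 0.5):
    """scores[i][j] = <RoPE(q_i, i), RoPE(k_j, j)> for one batch element and one head."""
    d_k = len(queries[0])
    d_rope = int(d_k * rope_percentage)
    q_rot = [rope_partial(q, i, d_rope) for i, q in enumerate(queries)]
    k_rot = [rope_partial(k, j, d_rope) for j, k in enumerate(keys)]
    return [[sum(a * b for a, b in zip(qi, kj)) for kj in k_rot] for qi in q_rot]


# The same query and key vector at every position: scores then depend only on i - j.
q = [0.3, -1.0, 0.8, 0.5, 1.0, 2.0, -0.5, 0.1]   # d_k = 8, so d_rope = 4
k = [1.1, 0.4, -0.2, 0.9, 0.5, -1.0, 0.3, 0.7]
s = rope_scores([q] * 5, [k] * 5)
print(s[3][1], s[4][2])  # equal up to floating-point rounding (both have offset i - j = 2)
```

With identical content at every position, the score matrix is constant along each diagonal. This is the relative-position property derived earlier, now seen at the level of the full attention matrix.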