This is an implementation of Rotary Positional Embeddings (RoPE) in PyTorch.
Rotary Positional Embeddings (RoPE) encode position information of tokens with a rotation matrix that naturally incorporates explicit relative position dependency.
Here's the training code for training a transformer model with RoPE on the Tiny Shakespeare dataset.
```python
import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
```

Rotary encoding transforms pairs of features by rotating them in the 2D plane. That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it by an angle depending on the position of the token.
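Let $x^{(1)}_m$ and $x^{(2)}_m$ be two features of the key or query of any head at position $m$. Or, for simplicity, assume $x$ has only two features. Then the transformation is

$$
\begin{aligned}
\text{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m\big) &=
\begin{pmatrix}
\cos m\theta & -\sin m\theta \\
\sin m\theta & \cos m\theta
\end{pmatrix}
\begin{pmatrix}
x^{(1)}_m \\
x^{(2)}_m
\end{pmatrix} \\
&=
\begin{pmatrix}
x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m\theta \\
x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m\theta
\end{pmatrix}
\end{aligned}
$$

where $\theta$ is a constant angle. The other pairs of features are transformed similarly.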
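A minimal plain-Python sketch of the 2D rotation above, assuming scalar features. `rope_pair` and `matmul_2x2` are illustrative names, not part of the labml code. It checks that the expanded formula matches the rotation-matrix form and that the rotation leaves the length of the pair unchanged.

```python
from __future__ import annotations

import math


def rope_pair(x1: float, x2: float, m: int, theta: float) -> tuple[float, float]:
    """Rotate the feature pair (x1, x2) by the angle m * theta (expanded form)."""
    c, s = math.cos(m * theta), math.sin(m * theta)
    return x1 * c - x2 * s, x2 * c + x1 * s


def matmul_2x2(mat: list[list[float]], vec: tuple[float, float]) -> tuple[float, float]:
    """Plain 2x2 matrix-vector product, used only to cross-check rope_pair."""
    return (mat[0][0] * vec[0] + mat[0][1] * vec[1],
            mat[1][0] * vec[0] + mat[1][1] * vec[1])


m, theta = 3, 0.5
angle = m * theta
# The rotation matrix from the first line of the equation
rotation = [[math.cos(angle), -math.sin(angle)],
            [math.sin(angle), math.cos(angle)]]

y = rope_pair(1.0, 2.0, m, theta)
y_matrix = matmul_2x2(rotation, (1.0, 2.0))
print(y, y_matrix)
# A rotation does not change the length of the pair
print(math.hypot(*y), math.hypot(1.0, 2.0))
```

At position $m = 0$ the angle is zero, so the pair is returned unchanged.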
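For a pair of features, the dot-product attention score between two positions $m$ and $n$ would be

$$
\begin{aligned}
\Big\langle \text{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m\big), \text{RoPE}\big(x^{(1)}_n, x^{(2)}_n, n\big) \Big\rangle
&= \big(x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m\theta\big)\big(x^{(1)}_n \cos n\theta - x^{(2)}_n \sin n\theta\big) \\
&\quad + \big(x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m\theta\big)\big(x^{(2)}_n \cos n\theta + x^{(1)}_n \sin n\theta\big) \\
&= x^{(1)}_m x^{(1)}_n \big(\cos m\theta \cos n\theta + \sin m\theta \sin n\theta\big) \\
&\quad + x^{(1)}_m x^{(2)}_n \big(-\cos m\theta \sin n\theta + \sin m\theta \cos n\theta\big) \\
&\quad + x^{(2)}_m x^{(1)}_n \big(-\sin m\theta \cos n\theta + \cos m\theta \sin n\theta\big) \\
&\quad + x^{(2)}_m x^{(2)}_n \big(\sin m\theta \sin n\theta + \cos m\theta \cos n\theta\big) \\
&= x^{(1)}_m x^{(1)}_n \cos (m - n)\theta + x^{(1)}_m x^{(2)}_n \sin (m - n)\theta \\
&\quad - x^{(2)}_m x^{(1)}_n \sin (m - n)\theta + x^{(2)}_m x^{(2)}_n \cos (m - n)\theta \\
&= \begin{pmatrix}
x^{(1)}_m \cos (m - n)\theta - x^{(2)}_m \sin (m - n)\theta \\
x^{(2)}_m \cos (m - n)\theta + x^{(1)}_m \sin (m - n)\theta
\end{pmatrix} \cdot
\begin{pmatrix}
x^{(1)}_n \\
x^{(2)}_n
\end{pmatrix} \\
&= \Big\langle \text{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m - n\big), \text{RoPE}\big(x^{(1)}_n, x^{(2)}_n, 0\big) \Big\rangle
\end{aligned}
$$

This shows that, for dot-product attention, rotary encodings give relative attention: the score depends on the two positions only through the offset $m - n$.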
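The derivation can be checked numerically. The plain-Python sketch below uses hypothetical helper names (`rope_pair`, `score`) and arbitrary example values. It uses different query and key pairs, since nothing in the algebra requires them to be the same vector, and compares scores for positions that share the same offset $m - n$.

```python
import math


def rope_pair(x1, x2, m, theta):
    """Rotate the feature pair (x1, x2) by the angle m * theta."""
    c, s = math.cos(m * theta), math.sin(m * theta)
    return x1 * c - x2 * s, x2 * c + x1 * s


def score(q, k, m, n, theta):
    """Dot product of the query pair rotated for position m and the key pair rotated for position n."""
    qm = rope_pair(q[0], q[1], m, theta)
    kn = rope_pair(k[0], k[1], n, theta)
    return qm[0] * kn[0] + qm[1] * kn[1]


theta = 0.3
q, k = (0.7, -1.2), (2.0, 0.5)

s_5_3 = score(q, k, 5, 3, theta)      # offset m - n = 2
s_12_10 = score(q, k, 12, 10, theta)  # same offset, different absolute positions
s_2_0 = score(q, k, 2, 0, theta)      # last line of the derivation: key left unrotated
s_5_4 = score(q, k, 5, 4, theta)      # offset 1, for contrast
print(s_5_3, s_12_10, s_2_0, s_5_4)
```

The first three scores agree up to floating-point rounding, while the score for offset 1 differs.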
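The features are grouped into pairs and handled as above, with a different $\theta$ for each pair.

The paper suggests using $\Theta = \{\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, \dots, \frac{d}{2}]\}$ for the $\frac{d}{2}$ pairs of features.

We pair feature $i$ with feature $i + \frac{d}{2}$. So for position $m$ we transform

$$
\begin{pmatrix}
x^{(i)}_m \\
x^{(i + \frac{d}{2})}_m
\end{pmatrix}
$$

to

$$
\begin{pmatrix}
x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\
x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i
\end{pmatrix}
$$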
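The module below computes this transformation for all pairs at once as `x * cos + neg_half(x) * sin`, an element-wise "rotate half" form. The plain-Python sketch here checks three equivalent ways of writing it: the per-pair formula, that element-wise form, and a complex-number view in which each pair becomes $z_i = x^{(i)} + j\,x^{(i + \frac{d}{2})}$ multiplied by $e^{j m \theta_i}$. All function names are hypothetical. The paper itself pairs adjacent features ($x^{(1)}$ with $x^{(2)}$, and so on); pairing $i$ with $i + \frac{d}{2}$ applies the same rotations to a permuted ordering of the features, and because queries and keys are permuted the same way, their dot products are unaffected.

```python
from __future__ import annotations

import cmath
import math


def rope_thetas(d: int, base: float = 10_000.0) -> list[float]:
    """theta_i = base^(-2(i-1)/d) for i = 1, ..., d/2 (written 0-based below)."""
    assert d % 2 == 0, "d must be even so the features can be paired"
    return [base ** (-2 * i / d) for i in range(d // 2)]


def rope_loop(x: list[float], m: int, base: float = 10_000.0) -> list[float]:
    """Rotate feature i together with feature i + d/2 by the angle m * theta_i."""
    half = len(x) // 2
    out = list(x)
    for i, theta in enumerate(rope_thetas(len(x), base)):
        c, s = math.cos(m * theta), math.sin(m * theta)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i + half] * c + x[i] * s
    return out


def rope_rotate_half(x: list[float], m: int, base: float = 10_000.0) -> list[float]:
    """Element-wise form: x * cos(angles) + neg_half(x) * sin(angles)."""
    half = len(x) // 2
    # [m*theta_1, ..., m*theta_{d/2}, m*theta_1, ..., m*theta_{d/2}]
    angles = [m * theta for theta in rope_thetas(len(x), base)] * 2
    # [-x^(d/2+1), ..., -x^(d), x^(1), ..., x^(d/2)]
    neg_half = [-v for v in x[half:]] + x[:half]
    return [v * math.cos(a) + n * math.sin(a) for v, n, a in zip(x, neg_half, angles)]


def rope_complex(x: list[float], m: int, base: float = 10_000.0) -> list[float]:
    """Same rotation as z_i * exp(j * m * theta_i), with z_i = x^(i) + j * x^(i + d/2)."""
    half = len(x) // 2
    z = [complex(x[i], x[i + half]) * cmath.exp(1j * m * theta)
         for i, theta in enumerate(rope_thetas(len(x), base))]
    # Real parts fill the first half, imaginary parts the second half
    return [w.real for w in z] + [w.imag for w in z]


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
for f in (rope_loop, rope_rotate_half, rope_complex):
    print(f.__name__, f(x, 2))
```

The element-wise form avoids any per-pair loop, which is why the module precomputes the concatenated angles once and caches their cosines and sines.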
```python
class RotaryPositionalEmbeddings(nn.Module):
    def __init__(self, d: int, base: int = 10_000):
        """
        * `d` is the number of features d
        * `base` is the constant used for calculating Theta
        """
        super().__init__()

        # Features are rotated in pairs, so `d` has to be even
        assert d % 2 == 0, "d must be even"
        self.base = base
        self.d = d
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        """
        Cache the cos and sin values
        """
        # Return if cache is already built
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return

        # Get sequence length
        seq_len = x.shape[0]

        # theta_i = base^(-2(i-1)/d) for i = 1, ..., d/2
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

        # Create position indexes [0, 1, ..., seq_len - 1]
        seq_idx = torch.arange(seq_len, device=x.device).float()

        # Calculate the product of position index and theta_i
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)

        # Concatenate so that for row m we have
        # [m theta_1, ..., m theta_{d/2}, m theta_1, ..., m theta_{d/2}]
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Cache them with shape [seq_len, 1, 1, d] so they broadcast over batch and heads
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # d/2
        d_2 = self.d // 2

        # Calculate [-x^(d/2 + 1), ..., -x^(d), x^(1), ..., x^(d/2)]
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the key or query tensor of the heads, with shape `[seq_len, batch_size, n_heads, d]`
        """
        # Cache the cos and sin values
        self._build_cache(x)

        # Split the features; we can choose to apply rotary embeddings only to a partial set of features
        x_rope, x_pass = x[..., :self.d], x[..., self.d:]

        # Calculate [-x^(d/2 + 1), ..., -x^(d), x^(1), ..., x^(d/2)]
        neg_half_x = self._neg_half(x_rope)

        # For i = 1, ..., d/2 and position m this computes
        #   first half:  x^(i) cos(m theta_i) - x^(i + d/2) sin(m theta_i)
        #   second half: x^(i + d/2) cos(m theta_i) + x^(i) sin(m theta_i)
        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])

        # Append the features that were not rotated
        return torch.cat((x_rope, x_pass), dim=-1)
```

We override multi-head attention from the original transformer.
```python
class RotaryPEMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
        super().__init__(heads, d_model, dropout_prob)

        # Rotary positional embedding layers; only the first `d_rope` features of each head are rotated
        d_rope = int(self.d_k * rope_percentage)
        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        # Calculate dot-product with RoPE; the result has shape [seq_len_q, seq_len_k, batch_size, heads]
        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
```

Testing RoPE with a simple example:
```python
def _test_rotary():
    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
    # Reshape to [seq_len=3, batch_size=1, n_heads=1, d=4]
    x = x[:, None, None, :]
    inspect(x)

    # `d` must be even, since features are rotated in pairs
    rotary_pe = RotaryPositionalEmbeddings(4)
    inspect(rotary_pe(x))


if __name__ == '__main__':
    _test_rotary()
```
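The `RotaryPEMultiHeadAttention` class above rotates only the first `d_rope = int(d_k * rope_percentage)` features of each head and then scores every query position `i` against every key position `j` with `einsum('ibhd,jbhd->ijbh', ...)`. The plain-Python sketch below spells out both steps with explicit loops over nested lists. `rope_partial` and `scores_with_rope` are hypothetical helpers for illustration, not part of labml, and the toy sizes (`d_k = 8`, `rope_percentage = 0.5`, so `d_rope = 4`) are assumptions. As in the module, `d_rope` must come out even.

```python
from __future__ import annotations

import math


def rope_partial(x: list[float], m: int, d_rope: int, base: float = 10_000.0) -> list[float]:
    """Rotate the first d_rope features of x by position m; later features pass through.

    Within the rotated block, feature i is paired with feature i + d_rope / 2.
    """
    half = d_rope // 2
    out = list(x)
    for i in range(half):
        theta = base ** (-2 * i / d_rope)
        c, s = math.cos(m * theta), math.sin(m * theta)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i + half] * c + x[i] * s
    return out


def scores_with_rope(query, key, d_rope: int):
    """Loop version of einsum('ibhd,jbhd->ijbh', rope(query), rope(key)).

    query is a nested list indexed [i][b][h][feature], key is [j][b][h][feature].
    The first index is the position, so query[i] is rotated by i and key[j] by j.
    Returns scores indexed [i][j][b][h].
    """
    q_rot = [[[rope_partial(vec, i, d_rope) for vec in heads] for heads in q_i]
             for i, q_i in enumerate(query)]
    k_rot = [[[rope_partial(vec, j, d_rope) for vec in heads] for heads in k_j]
             for j, k_j in enumerate(key)]
    return [[[[sum(u * v for u, v in zip(q_rot[i][b][h], k_rot[j][b][h]))
               for h in range(len(q_rot[i][b]))]
              for b in range(len(q_rot[i]))]
             for j in range(len(k_rot))]
            for i in range(len(q_rot))]


# Toy sizes (assumed): d_k = 8 features per head, half of them rotated
d_k, rope_percentage = 8, 0.5
d_rope = int(d_k * rope_percentage)

q_vec = [0.5, -1.0, 2.0, 0.3, 1.5, -0.7, 0.2, 0.9]
k_vec = [1.0, 0.4, -0.6, 2.2, -0.3, 0.8, 1.1, -1.4]
seq_len = 3
# Same query and key vector at every position; batch_size = 1, heads = 1
query = [[[q_vec]] for _ in range(seq_len)]
key = [[[k_vec]] for _ in range(seq_len)]

scores = scores_with_rope(query, key, d_rope)
print(scores[1][0][0][0], scores[2][1][0][0])  # both have offset i - j = 1
```

The pass-through features are never rotated, so their contribution to the score does not depend on position; with the same vectors at every position, the score still depends only on the offset `i - j`.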