This is an implementation of Rotary Positional Embeddings (RoPE) in PyTorch.
Rotary Positional Embeddings (RoPE) encode position information of tokens with a rotation matrix that naturally incorporates explicit relative position dependency.
Here is the code for training a transformer model with RoPE on the Tiny Shakespeare dataset.
import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention

Rotary encoding transforms pairs of features by rotating in the 2D plane. That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. Each pair can be considered a coordinate in a 2D plane, and the encoding rotates it by an angle depending on the position of the token.
Let $x^{(1)}_m$ and $x^{(2)}_m$ be two features of the key or query of any head at position $m$. Or for simplicity assume $x$ has only two features. Then the transformation is,

$$\mathrm{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m\big) =
\begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix}
\begin{pmatrix} x^{(1)}_m \\ x^{(2)}_m \end{pmatrix} =
\begin{pmatrix} x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m\theta \\ x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m\theta \end{pmatrix}$$
where $\theta$ is a constant angle. The other pairs of features are transformed similarly.
For a pair of features, the dot-product attention score between two positions $m$ and $n$ would be

$$\Big\langle \mathrm{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m\big), \mathrm{RoPE}\big(x^{(1)}_n, x^{(2)}_n, n\big) \Big\rangle =
\big(x^{(1)}_m x^{(1)}_n + x^{(2)}_m x^{(2)}_n\big) \cos (m - n)\theta +
\big(x^{(1)}_m x^{(2)}_n - x^{(2)}_m x^{(1)}_n\big) \sin (m - n)\theta$$

which depends on the positions only through the difference $m - n$. This shows that for dot-product attention the rotary encodings give relative attention.
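To make the relative-position property concrete, here is a minimal sketch (not part of the original code) that rotates two arbitrary 2D feature pairs by their position angles and checks that the dot product depends only on $m - n$:

import math
import torch


def rope_2d(pair: torch.Tensor, pos: int, theta: float) -> torch.Tensor:
    # Rotate the 2D feature pair by an angle of pos * theta
    c, s = math.cos(pos * theta), math.sin(pos * theta)
    rot = torch.tensor([[c, -s], [s, c]])
    return rot @ pair


theta = 0.1
q = torch.tensor([1.0, 2.0])
k = torch.tensor([3.0, 4.0])

# Two (m, n) position pairs with the same relative distance m - n = 2
score_a = rope_2d(q, 5, theta) @ rope_2d(k, 3, theta)
score_b = rope_2d(q, 9, theta) @ rope_2d(k, 7, theta)

# The scores match because the dot product depends only on m - n
assert torch.allclose(score_a, score_b)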
The features are grouped into pairs and handled as above. They use a different $\theta_i$ for each pair.

The paper suggests using $\Theta = \big\{\theta_i = 10000^{-\frac{2(i-1)}{d}},\ i \in [1, 2, \dots, \tfrac{d}{2}]\big\}$ for the $\frac{d}{2}$ pairs of features.
We pair feature $i$ with feature $i + \frac{d}{2}$. So for position $m$ we transform

$$\begin{pmatrix} x^{(i)}_m \\ x^{(i + \frac{d}{2})}_m \end{pmatrix}$$

to

$$\begin{pmatrix} x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\ x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i \end{pmatrix}$$
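The sketch below (an illustration, not part of the original code) applies this pairing scheme with an explicit loop for a single position $m$; the module that follows implements the same transformation with tensor operations instead of a loop:

import math
import torch


def rope_naive(x: torch.Tensor, m: int, base: float = 10_000.) -> torch.Tensor:
    # Rotate each (i, i + d/2) feature pair of a d-dimensional vector x at position m
    d = x.shape[-1]
    out = x.clone()
    for i in range(d // 2):
        # theta_i = base^{-2 i / d} (0-based i, equivalent to the paper's 1-based formula)
        theta_i = base ** (-2 * i / d)
        c, s = math.cos(m * theta_i), math.sin(m * theta_i)
        out[i] = x[i] * c - x[i + d // 2] * s
        out[i + d // 2] = x[i + d // 2] * c + x[i] * s
    return out


# A 4-feature vector at position m = 2: the pairs are (1, 3) and (2, 4) in 1-based indexing
print(rope_naive(torch.tensor([1., 2., 3., 4.]), m=2))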
class RotaryPositionalEmbeddings(nn.Module):

$d$ is the number of features and base is the constant used for calculating $\Theta$.

    def __init__(self, d: int, base: int = 10_000):
        super().__init__()
        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
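As a quick sanity check (an illustrative sketch, not part of the original code), the theta values computed above match the paper's $\theta_i = 10000^{-\frac{2(i-1)}{d}}$, shown here for $d = 8$:

import torch

d, base = 8, 10_000
theta = 1. / (base ** (torch.arange(0, d, 2).float() / d))

# The same values written with the paper's 1-based index i = 1, ..., d/2
expected = torch.tensor([base ** (-2 * (i - 1) / d) for i in range(1, d // 2 + 1)])

assert torch.allclose(theta, expected)
print(theta)  # values: 1.0, 0.1, 0.01, 0.001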
x is the Tensor at the head of a key or a query with shape [seq_len, batch_size, n_heads, d].

    def forward(self, x: torch.Tensor):
        # Extract the shape
        seq_len, batch_size, n_heads, d = x.shape
        d_2 = d // 2

        # Create position indexes [0, 1, ..., seq_len - 1]
        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)

        # Concatenate so that for row $m$ we have
        # $[m\theta_0, m\theta_1, ..., m\theta_{\frac{d}{2}-1}, m\theta_0, m\theta_1, ..., m\theta_{\frac{d}{2}-1}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Calculate
        # $[-x^{(\frac{d}{2}+1)}, -x^{(\frac{d}{2}+2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

        # Calculate
        # $\begin{pmatrix} x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\ x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i \end{pmatrix}$
        # for each pair of features
        rx = (x * idx_theta2.cos()[:, None, None, :]) + (neg_half_x * idx_theta2.sin()[:, None, None, :])

        return rx
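A small usage sketch (illustrative, not part of the original code): applying the module to a random tensor preserves the shape, and the first position is left unchanged because $\cos 0 = 1$ and $\sin 0 = 0$:

import torch

rotary_pe = RotaryPositionalEmbeddings(d=16)

# A dummy key/query tensor with shape [seq_len, batch_size, n_heads, d]
x = torch.randn(10, 2, 4, 16)
rx = rotary_pe(x)

# Shape is preserved and position 0 is rotated by a zero angle, i.e. unchanged
assert rx.shape == x.shape
assert torch.allclose(rx[0], x[0])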
We override multi-head attention from the original transformer.

class RotaryPEMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        # The linear transformations do not need a bias since we explicitly include it
        # when calculating scores. However, having a bias for value might make sense.
        super().__init__(heads, d_model, dropout_prob, bias=False)

        # Rotary positional embedding layers
        self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
        self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        # Calculate dot-product with RoPE
        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
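The einsum pattern 'ibhd,jbhd->ijbh' contracts the per-head feature dimension and produces scores indexed by query position, key position, batch, and head. A quick shape check (illustrative sketch, not part of the original code):

import torch

seq_len, batch_size, n_heads, d_k = 10, 2, 4, 16

# Per-head query and key tensors with shape [seq_len, batch_size, n_heads, d_k]
query = torch.randn(seq_len, batch_size, n_heads, d_k)
key = torch.randn(seq_len, batch_size, n_heads, d_k)

# Dot product over d_k for every (query position i, key position j) pair
scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

assert scores.shape == (seq_len, seq_len, batch_size, n_heads)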
Testing RoPE with a simple example.

def _test_rotary():
    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
    x = x[:, None, None, :]
    inspect(x)

    rotary_pe = RotaryPositionalEmbeddings(3)
    inspect(rotary_pe(x))


if __name__ == '__main__':
    _test_rotary()