partial rope embeddings

Author: Varuna Jayasiri
Date:   2022-05-31 11:51:03 +05:30
Parent: e56ea23c80
Commit: 10f44d0b21
5 changed files with 29 additions and 19 deletions

View File

@@ -19,3 +19,4 @@ indicators:
     name: optim.*
 options:
   comet: false
+  web_api: http://localhost:5000/api/v1/track?

View File

@@ -163,8 +163,10 @@ class RotaryPositionalEmbeddings(nn.Module):
         """
         self._build_cache(x)

+        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
+
         # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
-        neg_half_x = self._neg_half(x)
+        neg_half_x = self._neg_half(x_rope)

         # Calculate
         #
@@ -176,10 +178,10 @@ class RotaryPositionalEmbeddings(nn.Module):
         # \end{align}
         #
         # for $i \in {1, 2, ..., \frac{d}{2}}$
-        rx = (x * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])

         #
-        return rx
+        return torch.cat((x_rope, x_pass), dim=-1)


 class RotaryPEMultiHeadAttention(MultiHeadAttention):
@@ -189,15 +191,16 @@ class RotaryPEMultiHeadAttention(MultiHeadAttention):
     We override [multi-head attention from original transformer](../mha.html).
     """

-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1):
         # The linear transformations do not need a bias since we
         # explicitly include it when calculating scores.
         # However having a bias for `value` might make sense.
         super().__init__(heads, d_model, dropout_prob, bias=False)

         # Rotary positional embedding layers
-        self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
-        self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
+        d_rope = int(self.d_k * rope_percentage)
+        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)

     def get_scores(self, query: torch.Tensor, key: torch.Tensor):
         """

View File

@@ -20,7 +20,7 @@ from labml_nn.transformers.basic.autoregressive_experiment import Autoregressive
 # ### Rotary PE attention
 def _rotary_pe_mha(c: TransformerConfigs):
     from labml_nn.transformers.rope import RotaryPEMultiHeadAttention
-    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model)
+    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 0.5)


 # Configuration options
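
The factory above now fixes `rope_percentage` at 0.5 when it builds the attention layer from the transformer configs. A rough usage sketch, assuming `labml_nn` is installed and its `MultiHeadAttention` keeps the keyword-only `query`/`key`/`value` interface and the `[seq_len, batch_size, d_model]` input layout:

    import torch
    from labml_nn.transformers.rope import RotaryPEMultiHeadAttention

    # heads=8 over d_model=512 gives d_k=64, so rope_percentage=0.5 rotates the first
    # 32 features of every head and leaves the other 32 as they are.
    mha = RotaryPEMultiHeadAttention(heads=8, d_model=512, rope_percentage=0.5)

    seq_len, batch_size, d_model = 10, 4, 512
    x = torch.randn(seq_len, batch_size, d_model)
    # Self-attention: query, key and value all come from the same tensor
    out = mha(query=x, key=x, value=x)
    assert out.shape == (seq_len, batch_size, d_model)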

View File

@@ -48,8 +48,10 @@ class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
         """
         self._build_cache(x)

+        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
+
         # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
-        neg_half_x = self._neg_half(x)
+        neg_half_x = self._neg_half(x_rope)

         # Calculate
         #
@@ -65,10 +67,10 @@ class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
         # \end{align}
         #
         # for $i \in {1, 2, ..., \frac{d}{2}}$
-        rx = (x * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])

         #
-        return rx
+        return torch.cat((x_rope, x_pass), dim=-1)


 class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
@@ -78,17 +80,18 @@ class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
     We override [multi-head attention from original transformer](../mha.html).
     """

-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1):
         # The linear transformations do not need a bias since we
         # explicitly include it when calculating scores.
         # However having a bias for `value` might make sense.
         super().__init__(heads, d_model, dropout_prob, bias=False)

         # Rotary positional embedding layers
-        self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
-        self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
-        self.value_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
-        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(self.d_k)
+        d_rope = int(self.d_k * rope_percentage)
+        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope)

     def get_scores(self, query: torch.Tensor, key: torch.Tensor):
         """

View File

@@ -13,16 +13,19 @@ This is an annotated PyTorch experiment to train a transformer model with Rotary
 from labml import experiment
 from labml.configs import calculate
+from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression
 from labml_nn.transformers import TransformerConfigs
-from labml_nn.transformers.rope.experiment import Configs
+from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs


 # ### Rotary PE attention
+class Configs(RoPEConfigs):  # , ArithmeticAutoregression):
+    pass
+
+
 def _rotary_value_pe_mha(c: TransformerConfigs):
     from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
-    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model)
+    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 0.5)


 # Configuration options
@@ -33,7 +36,7 @@ calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_
 def main():
     # Create experiment
-    experiment.create(name="rotary_pe_transformer", writers={'screen'})
+    experiment.create(name="rotary_pe_transformer", writers={'screen', 'labml'})
     # Create configs
     conf = Configs()
     # Override configurations
@@ -43,8 +46,8 @@ def main():
         'transformer.tgt_embed': 'no_pos',

         # Encoder with RoPE
-        'transformer.encoder_attn': 'rotary_value',
         # 'transformer.encoder_attn': 'rotary_value',
+        'transformer.encoder_attn': 'rotary',

         #
         'model': 'rotary_pe_transformer',
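
The `'rotary'` and `'rotary_value'` strings used in the overrides above resolve to attention factories through labml's `calculate(...)` registrations, one of which is visible in the `-33,7 +36,7` hunk header. The sketch below follows that same pattern; the option name `'rotary_half'` and the factory `_rotary_half_mha` are made-up examples, not names from this commit.

    from labml.configs import calculate
    from labml_nn.transformers import TransformerConfigs
    from labml_nn.transformers.rope import RotaryPEMultiHeadAttention

    # A factory builds the attention layer from the shared transformer configs.
    # Registering it under a string name lets main() select it with an override
    # such as 'transformer.encoder_attn': 'rotary_half'.
    def _rotary_half_mha(c: TransformerConfigs):
        return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 0.5)

    calculate(TransformerConfigs.encoder_attn, 'rotary_half', _rotary_half_mha)
    calculate(TransformerConfigs.decoder_attn, 'rotary_half', _rotary_half_mha)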