From 10f44d0b21da65e2d3da51bb57a368e9f54ceb82 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Tue, 31 May 2022 11:51:03 +0530
Subject: [PATCH] partial rope embeddings

---
 .labml.yaml                                  |  1 +
 labml_nn/transformers/rope/__init__.py       | 15 +++++++++------
 labml_nn/transformers/rope/experiment.py     |  2 +-
 .../transformers/rope/value_pe/__init__.py   | 19 +++++++++++--------
 .../transformers/rope/value_pe/experiment.py | 11 +++++++----
 5 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/.labml.yaml b/.labml.yaml
index 1290b7bf..5578d582 100644
--- a/.labml.yaml
+++ b/.labml.yaml
@@ -19,3 +19,4 @@ indicators:
     name: optim.*
     options:
       comet: false
+web_api: http://localhost:5000/api/v1/track?
diff --git a/labml_nn/transformers/rope/__init__.py b/labml_nn/transformers/rope/__init__.py
index bbbc5aa1..2afd077b 100644
--- a/labml_nn/transformers/rope/__init__.py
+++ b/labml_nn/transformers/rope/__init__.py
@@ -163,8 +163,10 @@ class RotaryPositionalEmbeddings(nn.Module):
         """
         self._build_cache(x)
 
+        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
+
         # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
-        neg_half_x = self._neg_half(x)
+        neg_half_x = self._neg_half(x_rope)
 
         # Calculate
         #
@@ -176,10 +178,10 @@ class RotaryPositionalEmbeddings(nn.Module):
         # \end{align}
         #
         # for $i \in {1, 2, ..., \frac{d}{2}}$
-        rx = (x * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
 
         #
-        return rx
+        return torch.cat((x_rope, x_pass), dim=-1)
 
 
 class RotaryPEMultiHeadAttention(MultiHeadAttention):
@@ -189,15 +191,16 @@ class RotaryPEMultiHeadAttention(MultiHeadAttention):
     We override [multi-head attention from original transformer](../mha.html).
     """
 
-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1):
         # The linear transformations do not need a bias since we
         # explicitly include it when calculating scores.
         # However having a bias for `value` might make sense.
         super().__init__(heads, d_model, dropout_prob, bias=False)
 
         # Rotary positional embedding layers
-        self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
-        self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
+        d_rope = int(self.d_k * rope_percentage)
+        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
 
     def get_scores(self, query: torch.Tensor, key: torch.Tensor):
         """
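The hunks above make RotaryPositionalEmbeddings.forward rotate only the first self.d features of each head and pass the remaining features through untouched, which lets RotaryPEMultiHeadAttention apply RoPE to a fraction (rope_percentage) of each query/key head. A minimal, self-contained sketch of that split-rotate-concatenate step, using a hypothetical helper partial_rotate and the standard RoPE frequencies (illustration only, not part of the patch):

import torch

def partial_rotate(x: torch.Tensor, d_rope: int) -> torch.Tensor:
    # x: [seq_len, batch_size, n_heads, d_k]; rotate the first d_rope features only
    seq_len = x.shape[0]
    x_rope, x_pass = x[..., :d_rope], x[..., d_rope:]

    # theta_i = 10000^(-2(i-1)/d) for i in 1..d_rope/2
    theta = 1. / (10_000 ** (torch.arange(0, d_rope, 2).float() / d_rope))
    idx_theta = torch.einsum('n,d->nd', torch.arange(seq_len).float(), theta)
    idx_theta2 = torch.cat([idx_theta, idx_theta], dim=-1)
    cos = idx_theta2.cos()[:, None, None, :]
    sin = idx_theta2.sin()[:, None, None, :]

    # [-x^(d/2 + 1), ..., -x^(d), x^(1), ..., x^(d/2)]
    neg_half = torch.cat([-x_rope[..., d_rope // 2:], x_rope[..., :d_rope // 2]], dim=-1)

    # Rotate the first d_rope features, then re-attach the untouched remainder
    return torch.cat((x_rope * cos + neg_half * sin, x_pass), dim=-1)

x = torch.randn(10, 2, 8, 64)
print(partial_rotate(x, d_rope=32).shape)  # torch.Size([10, 2, 8, 64])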
diff --git a/labml_nn/transformers/rope/experiment.py b/labml_nn/transformers/rope/experiment.py
index 3ebc65c9..b9aa3c5e 100644
--- a/labml_nn/transformers/rope/experiment.py
+++ b/labml_nn/transformers/rope/experiment.py
@@ -20,7 +20,7 @@ from labml_nn.transformers.basic.autoregressive_experiment import Autoregressive
 # ### Rotary PE attention
 def _rotary_pe_mha(c: TransformerConfigs):
     from labml_nn.transformers.rope import RotaryPEMultiHeadAttention
-    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model)
+    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 0.5)
 
 
 # Configuration options
diff --git a/labml_nn/transformers/rope/value_pe/__init__.py b/labml_nn/transformers/rope/value_pe/__init__.py
index 1305e222..b2064dc1 100644
--- a/labml_nn/transformers/rope/value_pe/__init__.py
+++ b/labml_nn/transformers/rope/value_pe/__init__.py
@@ -48,8 +48,10 @@ class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
         """
         self._build_cache(x)
 
+        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
+
         # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
-        neg_half_x = self._neg_half(x)
+        neg_half_x = self._neg_half(x_rope)
 
         # Calculate
         #
@@ -65,10 +67,10 @@ class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
         # \end{align}
         #
         # for $i \in {1, 2, ..., \frac{d}{2}}$
-        rx = (x * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
 
         #
-        return rx
+        return torch.cat((x_rope, x_pass), dim=-1)
 
 
 class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
@@ -78,17 +80,18 @@ class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
     We override [multi-head attention from original transformer](../mha.html).
     """
 
-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1):
         # The linear transformations do not need a bias since we
         # explicitly include it when calculating scores.
         # However having a bias for `value` might make sense.
         super().__init__(heads, d_model, dropout_prob, bias=False)
 
         # Rotary positional embedding layers
-        self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
-        self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
-        self.value_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
-        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(self.d_k)
+        d_rope = int(self.d_k * rope_percentage)
+        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope)
 
     def get_scores(self, query: torch.Tensor, key: torch.Tensor):
         """
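ReverseRotaryPositionalEmbeddings above differs from the forward rotation only in the sign of the sin term, i.e. it rotates each feature pair by -m*theta_i, so reverse-rotating a rotated vector recovers the original. A standalone sketch of that round trip with hypothetical rotate/reverse_rotate helpers and the standard RoPE frequencies (illustration only, not the module's API):

import torch

def rope_cos_sin(seq_len: int, d: int):
    # Shared cos/sin tables for a d-dimensional rotation (d must be even)
    theta = 1. / (10_000 ** (torch.arange(0, d, 2).float() / d))
    idx_theta = torch.einsum('n,d->nd', torch.arange(seq_len).float(), theta)
    idx_theta2 = torch.cat([idx_theta, idx_theta], dim=-1)
    return idx_theta2.cos(), idx_theta2.sin()

def neg_half(x: torch.Tensor):
    d_2 = x.shape[-1] // 2
    return torch.cat([-x[..., d_2:], x[..., :d_2]], dim=-1)

def rotate(x, cos, sin):
    # Forward rotation: x * cos + neg_half(x) * sin
    return x * cos + neg_half(x) * sin

def reverse_rotate(x, cos, sin):
    # Reverse rotation: sin term sign flipped, as in ReverseRotaryPositionalEmbeddings
    return x * cos - neg_half(x) * sin

seq_len, d = 16, 32
cos, sin = rope_cos_sin(seq_len, d)
v = torch.randn(seq_len, d)
# The reverse rotation undoes the forward rotation (up to floating point error)
assert torch.allclose(reverse_rotate(rotate(v, cos, sin), cos, sin), v, atol=1e-5)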
diff --git a/labml_nn/transformers/rope/value_pe/experiment.py b/labml_nn/transformers/rope/value_pe/experiment.py
index 4b29658c..791efbed 100644
--- a/labml_nn/transformers/rope/value_pe/experiment.py
+++ b/labml_nn/transformers/rope/value_pe/experiment.py
@@ -13,16 +13,19 @@ This is an annotated PyTorch experiment to train a transformer model with Rotary
 
 from labml import experiment
 from labml.configs import calculate
+from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression
 from labml_nn.transformers import TransformerConfigs
-from labml_nn.transformers.rope.experiment import Configs
+from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs
 
 
 # ### Rotary PE attention
+class Configs(RoPEConfigs):  # , ArithmeticAutoregression):
+    pass
+
+
 def _rotary_value_pe_mha(c: TransformerConfigs):
     from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
-    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model)
+    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 0.5)
 
 
 # Configuration options
@@ -33,7 +36,7 @@ calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_
 
 def main():
     # Create experiment
-    experiment.create(name="rotary_pe_transformer", writers={'screen'})
+    experiment.create(name="rotary_pe_transformer", writers={'screen', 'labml'})
     # Create configs
     conf = Configs()
     # Override configurations
@@ -43,8 +46,8 @@ def main():
         'transformer.tgt_embed': 'no_pos',
 
         # Encoder with RoPE
-        'transformer.encoder_attn': 'rotary_value',
+        # 'transformer.encoder_attn': 'rotary_value',
+        'transformer.encoder_attn': 'rotary',
 
         #
         'model': 'rotary_pe_transformer',
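For reference, a possible way to exercise the new rope_percentage argument directly, assuming labml_nn's MultiHeadAttention keyword-only forward and [seq_len, batch_size, d_model] inputs (usage sketch, not part of the patch):

import torch

from labml_nn.transformers.rope import RotaryPEMultiHeadAttention

# d_k = 512 // 8 = 64, so rope_percentage=0.5 rotates the first 32 features of each
# query/key head and leaves the remaining 32 untouched.
mha = RotaryPEMultiHeadAttention(heads=8, d_model=512, rope_percentage=0.5)

x = torch.randn(100, 4, 512)  # [seq_len, batch_size, d_model]
out = mha(query=x, key=x, value=x)
print(out.shape)  # torch.Size([100, 4, 512])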