mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 01:13:00 +08:00
partial rope embeddings
This commit is contained in:
@ -19,3 +19,4 @@ indicators:
|
||||
name: optim.*
|
||||
options:
|
||||
comet: false
|
||||
web_api: http://localhost:5000/api/v1/track?
|
||||
|
@ -163,8 +163,10 @@ class RotaryPositionalEmbeddings(nn.Module):
|
||||
"""
|
||||
self._build_cache(x)
|
||||
|
||||
x_rope, x_pass = x[..., :self.d], x[..., self.d:]
|
||||
|
||||
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
|
||||
neg_half_x = self._neg_half(x)
|
||||
neg_half_x = self._neg_half(x_rope)
|
||||
|
||||
# Calculate
|
||||
#
|
||||
@ -176,10 +178,10 @@ class RotaryPositionalEmbeddings(nn.Module):
|
||||
# \end{align}
|
||||
#
|
||||
# for $i \in {1, 2, ..., \frac{d}{2}}$
|
||||
rx = (x * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
|
||||
x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
|
||||
|
||||
#
|
||||
return rx
|
||||
return torch.cat((x_rope, x_pass), dim=-1)
|
||||
|
||||
|
||||
class RotaryPEMultiHeadAttention(MultiHeadAttention):
|
||||
@ -189,15 +191,16 @@ class RotaryPEMultiHeadAttention(MultiHeadAttention):
|
||||
We override [multi-head attention from original transformer](../mha.html).
|
||||
"""
|
||||
|
||||
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
|
||||
def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1):
|
||||
# The linear transformations do not need a bias since we
|
||||
# explicitly include it when calculating scores.
|
||||
# However having a bias for `value` might make sense.
|
||||
super().__init__(heads, d_model, dropout_prob, bias=False)
|
||||
|
||||
# Rotary positional embedding layers
|
||||
self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
|
||||
self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
|
||||
d_rope = int(self.d_k * rope_percentage)
|
||||
self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
|
||||
self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
|
||||
|
||||
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
|
||||
"""
|
||||
|
@ -20,7 +20,7 @@ from labml_nn.transformers.basic.autoregressive_experiment import Autoregressive
|
||||
# ### Rotary PE attention
|
||||
def _rotary_pe_mha(c: TransformerConfigs):
|
||||
from labml_nn.transformers.rope import RotaryPEMultiHeadAttention
|
||||
return RotaryPEMultiHeadAttention(c.n_heads, c.d_model)
|
||||
return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 0.5)
|
||||
|
||||
|
||||
# Configuration options
|
||||
|
@ -48,8 +48,10 @@ class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
|
||||
"""
|
||||
self._build_cache(x)
|
||||
|
||||
x_rope, x_pass = x[..., :self.d], x[..., self.d:]
|
||||
|
||||
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
|
||||
neg_half_x = self._neg_half(x)
|
||||
neg_half_x = self._neg_half(x_rope)
|
||||
|
||||
# Calculate
|
||||
#
|
||||
@ -65,10 +67,10 @@ class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
|
||||
# \end{align}
|
||||
#
|
||||
# for $i \in {1, 2, ..., \frac{d}{2}}$
|
||||
rx = (x * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
|
||||
x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
|
||||
|
||||
#
|
||||
return rx
|
||||
return torch.cat((x_rope, x_pass), dim=-1)
|
||||
|
||||
|
||||
class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
|
||||
@ -78,17 +80,18 @@ class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
|
||||
We override [multi-head attention from original transformer](../mha.html).
|
||||
"""
|
||||
|
||||
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
|
||||
def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1):
|
||||
# The linear transformations do not need a bias since we
|
||||
# explicitly include it when calculating scores.
|
||||
# However having a bias for `value` might make sense.
|
||||
super().__init__(heads, d_model, dropout_prob, bias=False)
|
||||
|
||||
# Rotary positional embedding layers
|
||||
self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
|
||||
self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
|
||||
self.value_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
|
||||
self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(self.d_k)
|
||||
d_rope = int(self.d_k * rope_percentage)
|
||||
self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
|
||||
self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
|
||||
self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope)
|
||||
self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope)
|
||||
|
||||
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
|
||||
"""
|
||||
|
@ -13,16 +13,19 @@ This is an annotated PyTorch experiment to train a transformer model with Rotary
|
||||
|
||||
from labml import experiment
|
||||
from labml.configs import calculate
|
||||
from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression
|
||||
from labml_nn.transformers import TransformerConfigs
|
||||
from labml_nn.transformers.rope.experiment import Configs
|
||||
from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs
|
||||
|
||||
|
||||
# ### Rotary PE attention
|
||||
|
||||
class Configs(RoPEConfigs): # , ArithmeticAutoregression):
|
||||
pass
|
||||
|
||||
def _rotary_value_pe_mha(c: TransformerConfigs):
|
||||
from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
|
||||
return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model)
|
||||
return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 0.5)
|
||||
|
||||
|
||||
# Configuration options
|
||||
@ -33,7 +36,7 @@ calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_
|
||||
|
||||
def main():
|
||||
# Create experiment
|
||||
experiment.create(name="rotary_pe_transformer", writers={'screen'})
|
||||
experiment.create(name="rotary_pe_transformer", writers={'screen', 'labml'})
|
||||
# Create configs
|
||||
conf = Configs()
|
||||
# Override configurations
|
||||
@ -43,8 +46,8 @@ def main():
|
||||
'transformer.tgt_embed': 'no_pos',
|
||||
|
||||
# Encoder with RoPE
|
||||
'transformer.encoder_attn': 'rotary_value',
|
||||
# 'transformer.encoder_attn': 'rotary_value',
|
||||
'transformer.encoder_attn': 'rotary',
|
||||
|
||||
#
|
||||
'model': 'rotary_pe_transformer',
|
||||
|
Reference in New Issue
Block a user