Repository: https://github.com/labmlai/annotated_deep_learning_paper_implementations (mirror)
Commit: partial rope embeddings
@@ -19,3 +19,4 @@ indicators:
     name: optim.*
 options:
   comet: false
+  web_api: http://localhost:5000/api/v1/track?
@@ -163,8 +163,10 @@ class RotaryPositionalEmbeddings(nn.Module):
         """
         self._build_cache(x)

+        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
+
         # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
-        neg_half_x = self._neg_half(x)
+        neg_half_x = self._neg_half(x_rope)

         # Calculate
         #
@@ -176,10 +178,10 @@ class RotaryPositionalEmbeddings(nn.Module):
         # \end{align}
         #
         # for $i \in {1, 2, ..., \frac{d}{2}}$
-        rx = (x * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])

         #
-        return rx
+        return torch.cat((x_rope, x_pass), dim=-1)


 class RotaryPEMultiHeadAttention(MultiHeadAttention):
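Taken together, the two hunks above change `RotaryPositionalEmbeddings.forward` (the class in `labml_nn.transformers.rope`) so that only the first `self.d` features of each head are rotated, while the remaining features (`x_pass`) bypass the rotation and are concatenated back at the end. The following is a minimal, self-contained sketch of that partial rotation; `partial_rope` and the literal shapes are illustrative assumptions rather than repository code, but the tensor layout `[seq_len, batch, heads, d_k]` follows the class's convention.

import torch


def partial_rope(x: torch.Tensor, d: int, base: float = 10_000.) -> torch.Tensor:
    """Rotate only the first `d` features of `x`, which has shape `[seq_len, batch, heads, d_k]`."""
    seq_len = x.shape[0]
    # Features that get rotary embeddings vs. features that pass through unchanged
    x_rope, x_pass = x[..., :d], x[..., d:]

    # theta_i = 1 / base^(2i/d), giving d/2 frequencies
    theta = 1. / (base ** (torch.arange(0, d, 2).float() / d))
    # Outer product of positions m and frequencies theta_i -> [seq_len, d/2]
    idx_theta = torch.einsum('n,d->nd', torch.arange(seq_len, dtype=torch.float32), theta)
    # Duplicate so it lines up with the two halves of x_rope -> [seq_len, 1, 1, d]
    m_theta = torch.cat([idx_theta, idx_theta], dim=-1)[:, None, None, :]

    # [-x^(d/2+1), ..., -x^(d), x^(1), ..., x^(d/2)]
    neg_half_x = torch.cat([-x_rope[..., d // 2:], x_rope[..., :d // 2]], dim=-1)
    # The rotation itself
    x_rope = x_rope * m_theta.cos() + neg_half_x * m_theta.sin()

    # Put the untouched features back
    return torch.cat((x_rope, x_pass), dim=-1)


# Rotate half of a 16-dimensional head for a [seq_len=10, batch=2, heads=4] input
out = partial_rope(torch.randn(10, 2, 4, 16), d=8)
assert out.shape == (10, 2, 4, 16)

Letting `d` be smaller than the head size, i.e. keeping part of each head position-agnostic, is the point of the commit; the `rope_percentage` argument introduced in the next hunk controls how large that rotated slice is.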
@@ -189,15 +191,16 @@ class RotaryPEMultiHeadAttention(MultiHeadAttention):
     We override [multi-head attention from original transformer](../mha.html).
     """

-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1):
         # The linear transformations do not need a bias since we
         # explicitly include it when calculating scores.
         # However having a bias for `value` might make sense.
         super().__init__(heads, d_model, dropout_prob, bias=False)

         # Rotary positional embedding layers
-        self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
-        self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
+        d_rope = int(self.d_k * rope_percentage)
+        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)

     def get_scores(self, query: torch.Tensor, key: torch.Tensor):
         """
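The hunk above threads a `rope_percentage` argument through `RotaryPEMultiHeadAttention`: the query and key rotary layers are now built with `d_rope = int(self.d_k * rope_percentage)` instead of the full `self.d_k`. A small illustrative calculation (the `d_model` and `heads` values are made up, not taken from the experiment):

d_model, heads = 512, 8
d_k = d_model // heads                      # 64 features per head
for rope_percentage in (1.0, 0.5, 0.25):
    d_rope = int(d_k * rope_percentage)     # how many features per head get rotated
    print(f'rope_percentage={rope_percentage}: rotate {d_rope} of {d_k} features per head')
# rope_percentage=1.0: rotate 64 of 64 features per head
# rope_percentage=0.5: rotate 32 of 64 features per head
# rope_percentage=0.25: rotate 16 of 64 features per head

Note that the new default of 0.5 changes behaviour relative to the previous code, which always rotated all `d_k` features; passing `rope_percentage=1.0` recovers the old behaviour.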
@@ -20,7 +20,7 @@ from labml_nn.transformers.basic.autoregressive_experiment import Autoregressive
 # ### Rotary PE attention
 def _rotary_pe_mha(c: TransformerConfigs):
     from labml_nn.transformers.rope import RotaryPEMultiHeadAttention
-    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model)
+    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 0.5)


 # Configuration options
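This hunk is from the RoPE experiment module (`labml_nn.transformers.rope.experiment`, judging by the import path used elsewhere in the diff): the builder now hard-codes a `rope_percentage` of 0.5. For context, such builders are registered as named options on `TransformerConfigs` with `labml.configs.calculate`; the same pattern is visible in a later hunk header for `decoder_mem_attn`. A sketch of that wiring, assuming the `'rotary'` option name that `main()` selects further down:

from labml.configs import calculate
from labml_nn.transformers import TransformerConfigs


def _rotary_pe_mha(c: TransformerConfigs):
    from labml_nn.transformers.rope import RotaryPEMultiHeadAttention
    # 0.5 -> rotate half of each head's features (the new rope_percentage argument)
    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 0.5)


# Register 'rotary' as a selectable option for the encoder self-attention config,
# so an experiment can pick it with {'transformer.encoder_attn': 'rotary'}
calculate(TransformerConfigs.encoder_attn, 'rotary', _rotary_pe_mha)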
@@ -48,8 +48,10 @@ class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
         """
         self._build_cache(x)

+        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
+
         # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
-        neg_half_x = self._neg_half(x)
+        neg_half_x = self._neg_half(x_rope)

         # Calculate
         #
@@ -65,10 +67,10 @@ class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
         # \end{align}
         #
         # for $i \in {1, 2, ..., \frac{d}{2}}$
-        rx = (x * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
+        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])

         #
-        return rx
+        return torch.cat((x_rope, x_pass), dim=-1)


 class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
|
|||||||
We override [multi-head attention from original transformer](../mha.html).
|
We override [multi-head attention from original transformer](../mha.html).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
|
def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.1):
|
||||||
# The linear transformations do not need a bias since we
|
# The linear transformations do not need a bias since we
|
||||||
# explicitly include it when calculating scores.
|
# explicitly include it when calculating scores.
|
||||||
# However having a bias for `value` might make sense.
|
# However having a bias for `value` might make sense.
|
||||||
super().__init__(heads, d_model, dropout_prob, bias=False)
|
super().__init__(heads, d_model, dropout_prob, bias=False)
|
||||||
|
|
||||||
# Rotary positional embedding layers
|
# Rotary positional embedding layers
|
||||||
self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
|
d_rope = int(self.d_k * rope_percentage)
|
||||||
self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
|
self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
|
||||||
self.value_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
|
self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
|
||||||
self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(self.d_k)
|
self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope)
|
||||||
|
self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope)
|
||||||
|
|
||||||
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
|
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
|
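`RotaryValuePEMultiHeadAttention` now sizes all four rotary layers with `d_rope` as well. The value path presumably rotates the values before the attention-weighted sum and applies the reverse rotation to the result, so only relative offsets survive. A rough sketch of that wiring (the shapes, the dummy attention matrix and the einsum pattern are my assumptions based on the class's conventions, and it relies on the partial-rotation forward introduced above):

import torch
from labml_nn.transformers.rope import RotaryPositionalEmbeddings
from labml_nn.transformers.rope.value_pe import ReverseRotaryPositionalEmbeddings

d_k, rope_percentage = 64, 0.5
d_rope = int(d_k * rope_percentage)
value_rotary_pe = RotaryPositionalEmbeddings(d_rope)
value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope)

value = torch.randn(10, 2, 8, d_k)                        # [seq_len, batch, heads, d_k]
attn = torch.softmax(torch.randn(10, 10, 2, 8), dim=1)    # dummy [query, key, batch, heads] weights
# Rotate values at their positions, take the attention-weighted sum,
# then reverse-rotate at the output positions
x = torch.einsum('ijbh,jbhd->ibhd', attn, value_rotary_pe(value))
x = value_reverse_rotary_pe(x)
assert x.shape == (10, 2, 8, d_k)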
@@ -13,16 +13,19 @@ This is an annotated PyTorch experiment to train a transformer model with Rotary

 from labml import experiment
 from labml.configs import calculate
+from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression
 from labml_nn.transformers import TransformerConfigs
-from labml_nn.transformers.rope.experiment import Configs
+from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs


 # ### Rotary PE attention

+class Configs(RoPEConfigs): # , ArithmeticAutoregression):
+    pass

 def _rotary_value_pe_mha(c: TransformerConfigs):
     from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
-    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model)
+    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 0.5)


 # Configuration options
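This hunk is from the value-RoPE experiment file: the rope `Configs` is now imported under the alias `RoPEConfigs`, a new `ArithmeticAutoregression` import is added, and a local `Configs` subclass is introduced with the arithmetic mix-in left commented out. Presumably the intended end state (my assumption; it is not part of this commit) is a combined configuration like:

from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression
from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs


class Configs(RoPEConfigs, ArithmeticAutoregression):
    # RoPE transformer configs combined with the arithmetic autoregression task
    pass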
@@ -33,7 +36,7 @@ calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_

 def main():
     # Create experiment
-    experiment.create(name="rotary_pe_transformer", writers={'screen'})
+    experiment.create(name="rotary_pe_transformer", writers={'screen', 'labml'})
     # Create configs
     conf = Configs()
     # Override configurations
@@ -43,8 +46,8 @@ def main():
         'transformer.tgt_embed': 'no_pos',

         # Encoder with RoPE
-        'transformer.encoder_attn': 'rotary_value',
         # 'transformer.encoder_attn': 'rotary_value',
+        'transformer.encoder_attn': 'rotary',

         #
         'model': 'rotary_pe_transformer',
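The final two hunks adjust `main()`: results are now also written with the 'labml' writer (which presumably reports to the `web_api` endpoint added to the config file at the top of the diff), and the encoder attention is switched from the value-rotation variant to plain partial-RoPE attention, with the old option left commented out. Roughly, the relevant part of `main()` now reads as follows; this is a sketch assembled from the visible lines, other overrides in the real file are omitted, and the imported `Configs` stands in for the subclass defined in the experiment file itself:

from labml import experiment
# Stand-in for the `Configs` subclass defined in this experiment file
from labml_nn.transformers.rope.experiment import Configs


def main():
    # Create experiment; the 'labml' writer sends tracked metrics to the configured labml app
    experiment.create(name="rotary_pe_transformer", writers={'screen', 'labml'})
    # Create configs
    conf = Configs()
    # Override configurations
    experiment.configs(conf, {
        # No fixed positional encodings on the embeddings; RoPE supplies position information
        'transformer.tgt_embed': 'no_pos',

        # Encoder with RoPE; the value-rotation option is kept around but disabled for this run
        # 'transformer.encoder_attn': 'rotary_value',
        'transformer.encoder_attn': 'rotary',

        'model': 'rotary_pe_transformer',
    })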