From 0ce65adf9e602321109528b05cf99fccb16cd2de Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Fri, 3 Jun 2022 21:29:41 +0530 Subject: [PATCH] RoPER (#126) --- .gitignore | 3 +- .labml.yaml | 1 + docs/experiments/arithmetic_dataset.html | 900 ++++++++++++++++++ docs/normalization/deep_norm/experiment.html | 3 +- docs/sitemap.xml | 34 +- docs/transformers/rope/experiment.html | 4 +- docs/transformers/rope/index.html | 217 +++-- .../rope/value_pe/arithmetic_experiment.html | 403 ++++++++ .../rope/value_pe/experiment.html | 454 +++++++++ docs/transformers/rope/value_pe/index.html | 510 ++++++++++ labml_nn/experiments/arithmetic_dataset.py | 260 +++++ labml_nn/transformers/rope/__init__.py | 76 +- labml_nn/transformers/rope/experiment.py | 4 +- .../transformers/rope/value_pe/__init__.py | 246 +++++ .../rope/value_pe/arithmetic_experiment.py | 93 ++ .../transformers/rope/value_pe/experiment.py | 98 ++ setup.py | 6 +- 17 files changed, 3207 insertions(+), 105 deletions(-) create mode 100644 docs/experiments/arithmetic_dataset.html create mode 100644 docs/transformers/rope/value_pe/arithmetic_experiment.html create mode 100644 docs/transformers/rope/value_pe/experiment.html create mode 100644 docs/transformers/rope/value_pe/index.html create mode 100644 labml_nn/experiments/arithmetic_dataset.py create mode 100644 labml_nn/transformers/rope/value_pe/__init__.py create mode 100644 labml_nn/transformers/rope/value_pe/arithmetic_experiment.py create mode 100644 labml_nn/transformers/rope/value_pe/experiment.py diff --git a/.gitignore b/.gitignore index 6c0eec71..4336d66b 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,5 @@ logs html/ diagrams/ .comet.config -settings.md \ No newline at end of file +settings.md +labml_app.log \ No newline at end of file diff --git a/.labml.yaml b/.labml.yaml index 1290b7bf..2e384051 100644 --- a/.labml.yaml +++ b/.labml.yaml @@ -19,3 +19,4 @@ indicators: name: optim.* options: comet: false +web_api: http://localhost:5005/api/v1/track? diff --git a/docs/experiments/arithmetic_dataset.html b/docs/experiments/arithmetic_dataset.html new file mode 100644 index 00000000..c35bbd9a --- /dev/null +++ b/docs/experiments/arithmetic_dataset.html @@ -0,0 +1,900 @@ + + + + + + + + + + + + + + + + + + + + + + + Arithmetic Dataset + + + + + + + + + + +
This is based on code by Georges Harik (@gharik).

+ +
+
+
11import random
+12import string
+13from typing import List
+14
+15import torch
+16from labml.logger import Text
+17from torch.utils.data import DataLoader, Dataset
+18
+19from labml import monit, logger, tracker
+20from labml.configs import option
+21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
+
+
+
+
+ +

Arithmetic Dataset

+

This creates arithmetic addition problems and solutions with workings. We've only implemented addition so far.

+

It's based on character-level tokenization.

+ +
+
+
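A minimal usage sketch of the dataset defined below (the module path comes from this PR; the decoded sample is only illustrative, since problems are generated randomly):

from labml_nn.experiments.arithmetic_dataset import ArithmeticDataset

dataset = ArithmeticDataset(seq_len=64, max_digits=2, n_sequences=1)
# Decodes to packed problems of the form
# ?x=11+29; 1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1 x==40
print(dataset.decode(dataset.get_packed_math_input()))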
24class ArithmeticDataset(Dataset):
+
+
+
+
+ +
  • seq_len is the sequence length of generated math problems. We fill as many problems as possible up to this length.
  • max_digits is the maximum number of digits in the operand integers.
  • n_sequences is the number of sequences per epoch.
+ +
+
+
34    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
+
+
+
+
+ + +
+
+
41        self.n_sequences = n_sequences
+42        self.max_digits = max_digits
+43        self.seq_len = seq_len
+
+
+
+
+ +

Token id to string

+ +
+
+
45        self.itos = list(string.digits + 'xe =\n?+;')
+
+
+
+
+ +

Character to token id

+ +
+
+
47        self.stoi = {c: i for i, c in enumerate(self.itos)}
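For reference, a quick standalone check of the character-to-id mapping this builds (it duplicates the two lines above; token ids follow the order of itos):

import string

itos = list(string.digits + 'xe =\n?+;')
stoi = {c: i for i, c in enumerate(itos)}
# 18 tokens: '0'-'9' map to 0-9, then 'x', 'e', ' ', '=', '\n', '?', '+', ';'
assert len(itos) == 18
assert [stoi[c] for c in 'x=1+2;'] == [10, 13, 1, 16, 2, 17]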
+
+
+
+
+ +

Generates an integer with n_digits digits

+ +
+
+
49    @staticmethod
+50    def make_int(n_digits: int):
+
+
+
+
+ + +
+
+
54        res = 0
+55        for i in range(n_digits):
+56            d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
+57            res = res * 10 + d
+58
+59        return res
+
+
+
+
+ +

Generates the workings for x + y. For example, for 11+29 it generates 1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1.

+ +
+
+
61    @staticmethod
+62    def get_add_explanation(x: int, y: int):
+
+
+
+
+ + +
+
+
69        carry = 0
+70        e = 0
+71        explanation = []
+72        while x > 0 or y > 0 or carry > 0:
+73            rx, ry = x % 10, y % 10
+74            total = rx + ry + carry
+75            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
+76            x, y, carry = x // 10, y // 10, total // 10
+77            e += 1
+78
+79        return ' '.join(explanation)
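A worked check of the example above (assumes this module is importable; get_add_explanation is a static method, so no instance is needed):

from labml_nn.experiments.arithmetic_dataset import ArithmeticDataset

assert ArithmeticDataset.get_add_explanation(11, 29) == '1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1'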
+
+
+
+
+ +

Creates an arithmetic addition problem with workings and answer.

+ +
+
+
82    def make_add_problem(self):
+
+
+
+
+ + +
+
+
86        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+87        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+88
+89        explanation = self.get_add_explanation(x, y)
+90        return f"x={x}+{y}; {explanation} x=={x + y}\n"
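A light sanity-check sketch of the generated problem string (digits are random, so only the delimiters are checked; assumes the module is importable):

from labml_nn.experiments.arithmetic_dataset import ArithmeticDataset

ds = ArithmeticDataset(seq_len=128, max_digits=2, n_sequences=1)
problem = ds.make_add_problem()
# e.g. "x=11+29; 1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1 x==40\n"
assert problem.startswith('x=') and 'x==' in problem and problem.endswith('\n')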
+
+
+
+
+ +

Get arithmetic problem and answer. This is used for evaluation.

+ +
+
+
92    def get_qa(self):
+
+
+
+
+ + +
+
+
96        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+97        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+98
+99        return f'x={x}+{y};', f'{x + y}'
+
+
+
+
+ +

Generate multiple problems and pack them into a sequence.

+ +
+
+
101    def get_packed_math_input(self):
+
+
+
+
+ + +
+
+
105        s_enc = []
+106        while len(s_enc) <= self.seq_len:
+107            s_part = self.make_add_problem()
+108            s_part_enc = self.encode('?' + s_part)
+109            s_enc = s_enc + s_part_enc
+110        return s_enc
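A sketch of the packing behaviour (assumes the module is importable): problems are prefixed with '?' and appended until the encoded sequence passes seq_len; __getitem__ below then truncates it.

from labml_nn.experiments.arithmetic_dataset import ArithmeticDataset

ds = ArithmeticDataset(seq_len=32, max_digits=1, n_sequences=1)
packed = ds.get_packed_math_input()
assert len(packed) > ds.seq_len   # always overfills, then gets cut to seq_len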
+
+
+
+
+ +

Encode a given string

+ +
+
+
112    def encode(self, s: str):
+
+
+
+
+ + +
+
+
116        return [self.stoi[c] for c in s]
+
+
+
+
+ +

Decode a list of token ids

+ +
+
+
118    def decode(self, arr: List[int]):
+
+
+
+
+ + +
+
+
122        return ''.join([self.itos[c] for c in arr])
+
+
+
+
+ +

Get an input and target pair for auto-regressive modelling

+ +
+
+
124    def __getitem__(self, idx: int):
+
+
+
+
+ + +
+
+
128        s = torch.tensor(self.get_packed_math_input())
+129        return s[:self.seq_len], s[1:self.seq_len + 1]
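The input/target offset in a small standalone sketch: the target is the input shifted left by one token, so the model learns to predict the next character at every position.

import torch

s = torch.arange(10)                     # stand-in for an encoded sequence
seq_len = 8
inp, tgt = s[:seq_len], s[1:seq_len + 1]
assert torch.equal(tgt[:-1], inp[1:])    # target leads the input by one position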
+
+
+
+
+ +

Number of sequences per epoch

+ +
+
+
131    def __len__(self):
+
+
+
+
+ + +
+
+
135        return self.n_sequences
+
+
+
+
+ +

Arithmetic Task Experiment Configurations

+ +
+
+
138class ArithmeticAutoregression(NLPAutoRegressionConfigs):
+
+
+
+
+ +

Maximum number of digits per operand integer

+ +
+
+
143    max_digits: int = 4
+
+
+
+
+ +

Number of training sequences per epoch

+ +
+
+
145    train_sequences_per_epoch: int = 2 ** 12
+
+
+
+
+ +

Training data loader

+ +
+
+
147    train_loader: DataLoader = 'arithmetic_train_loader'
+
+
+
+
+ +

Number of problems in evaluation

+ +
+
+
149    n_tests: int = 64
+
+
+
+
+ +

No need for a validation dataset

+ +
+
+
151    validator = None
+
+
+
+
+ +

Number of times to run evaluations per epoch

+ +
+
+
153    inner_iterations = 4
+
+
+
+
+ +

Number of tokens in the vocabulary

+ +
+
+
155    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
+
+
+
+
+ +

Evaluation

+

We use the sampling function to evaluate the model on a set of problems

+ +
+
+
157    @torch.no_grad()
+158    def sample(self):
+
+
+
+
+ +

Skip in the first epoch

+ +
+
+
166        if self.training_loop.idx < 1:
+167            return
+
+
+
+
+ +

Create a dataset to generate problems

+ +
+
+
170        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
+
+
+
+
+ +

Get a set of problems and answers

+ +
+
+
172        qa = [dataset.get_qa() for _ in range(self.n_tests)]
+
+
+
+
+ +

Collect the problems only

+ +
+
+
174        questions = [p[0] for p in qa]
+
+
+
+
+ +

Create a tensor with only the initial token

+ +
+
+
177        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
+
+
+
+
+ +

Move to device

+ +
+
+
179        data = data.to(self.device)
+
+
+
+
+ +

Number of sequences that have completed

+ +
+
+
182        finished = torch.zeros((len(questions),)).bool().to(self.device)
+
+
+
+
+ +

Token id of the newline character - this marks the end of the answer

+ +
+
+
184        new_line = dataset.stoi['\n']
+
+
+
+
+ +

Sampled results

+ +
+
+
187        results = [p[0] for p in questions]
+
+
+
+
+ +

Sample up to the sequence length

+ +
+
+
190        for i in monit.iterate('Sample', self.seq_len - 1):
+
+
+
+
+ +

If all the sequences have completed we skip this

+ +
+
+
192            if finished.sum() == len(finished):
+193                continue
+
+
+
+
+ +

Get the model output

+ +
+
+
196            output, *_ = self.model(data)
+
+
+
+
+ +

Get the model prediction (greedy)

+ +
+
+
198            output = output[-1].argmax(dim=-1)
+
+
+
+
+ +

Find which sequences have finished

+ +
+
+
201            finished = finished | (output == new_line)
+
+
+
+
+ +

Skip if all have finished

+ +
+
+
203            if finished.sum() == len(finished):
+204                continue
+
+
+
+
+ +

Override the predictions with the question characters while they are still being fed in

+ +
+
+
207            for j, p in enumerate(questions):
+208                if len(p) > i + 1:
+209                    output[j] = dataset.stoi[p[i + 1]]
+
+
+
+
+ +

Add the next token to the input

+ +
+
+
212            data = torch.cat([data, output[None, :]], dim=0)
+
+
+
+
+ +

Get the sampled results

+ +
+
+
215            for j, c in enumerate(output):
+216                results[j] += dataset.itos[c]
+
+
+
+
+ +

Discard everything after the answer in the results

+ +
+
+
219        results = [r.split('\n')[0] for r in results]
+
+
+
+
+ +

Log a sample

+ +
+
+
222        res_sample = results[0].split(';')
+223        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
+
+
+
+
+ +

Get the answers

+ +
+
+
226        results = [r.split('x==')[-1] for r in results]
+
+
+
+
+ +

Count the number of correct answers

+ +
+
+
229        correct = 0
+230        for r, _qa in zip(results, qa):
+231            if r == _qa[1]:
+232                correct += 1
+
+
+
+
+ +

Log the score

+ +
+
+
235        tracker.save('score', correct / len(results))
+
+
+
+
+ +

Training data loader

+ +
+
+
238@option(ArithmeticAutoregression.train_loader)
+239def arithmetic_train_loader(c: ArithmeticAutoregression):
+
+
+
+
+ + +
+
+
243    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
+244                      batch_size=c.batch_size,
+245                      collate_fn=transpose_batch,
+246                      num_workers=4)
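A minimal wiring sketch in the style of the experiment scripts added in this PR (the experiment name and option values are illustrative, and a real run also needs the model-related configs those scripts set):

from labml import experiment

def main():
    experiment.create(name="arithmetic_addition", writers={'screen'})
    conf = ArithmeticAutoregression()
    experiment.configs(conf, {'max_digits': 4, 'batch_size': 32, 'seq_len': 256})
    with experiment.start():
        conf.run()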
+
+
+
+
+ +

Code to test generated problems

+ +
+
+
249def _test():
+
+
+
+
+ + +
+
+
253    dataset = ArithmeticDataset(256, 8, 10)
+254
+255    print(dataset.decode(dataset.get_packed_math_input()))
+
+
+
+
+ +

+ +
+
+
259if __name__ == '__main__':
+260    _test()
+
+
+ +
+ + + + \ No newline at end of file diff --git a/docs/normalization/deep_norm/experiment.html b/docs/normalization/deep_norm/experiment.html index 9f7e0095..8ccf394a 100644 --- a/docs/normalization/deep_norm/experiment.html +++ b/docs/normalization/deep_norm/experiment.html @@ -70,7 +70,8 @@ #

DeepNorm Experiment

-Open In Colab View Run
+Open In Colab View Run Open In Comet

+
15import copy
diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index 42cf8562..8f6077e4 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -204,7 +204,7 @@
 
     
       https://nn.labml.ai/normalization/deep_norm/index.html
-      2022-04-23T16:30:00+00:00
+      2022-05-18T16:30:00+00:00
       1.00
     
     
@@ -244,6 +244,13 @@
     
     
 
+    
+      https://nn.labml.ai/experiments/arithmetic_dataset.html
+      2022-06-02T16:30:00+00:00
+      1.00
+    
+    
+
     
       https://nn.labml.ai/experiments/index.html
       2020-12-26T16:30:00+00:00
@@ -603,14 +610,35 @@
 
     
       https://nn.labml.ai/transformers/rope/index.html
-      2022-04-05T16:30:00+00:00
+      2022-05-31T16:30:00+00:00
+      1.00
+    
+    
+
+    
+      https://nn.labml.ai/transformers/rope/value_pe/arithmetic_experiment.html
+      2022-06-02T16:30:00+00:00
+      1.00
+    
+    
+
+    
+      https://nn.labml.ai/transformers/rope/value_pe/index.html
+      2022-06-02T16:30:00+00:00
+      1.00
+    
+    
+
+    
+      https://nn.labml.ai/transformers/rope/value_pe/experiment.html
+      2022-05-31T16:30:00+00:00
       1.00
     
     
 
     
       https://nn.labml.ai/transformers/rope/experiment.html
-      2022-03-12T16:30:00+00:00
+      2022-05-31T16:30:00+00:00
       1.00
     
     
diff --git a/docs/transformers/rope/experiment.html b/docs/transformers/rope/experiment.html
index 156b2092..5ce9d242 100644
--- a/docs/transformers/rope/experiment.html
+++ b/docs/transformers/rope/experiment.html
@@ -92,7 +92,7 @@
         
21def _rotary_pe_mha(c: TransformerConfigs):
 22    from labml_nn.transformers.rope import RotaryPEMultiHeadAttention
-23    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model)
+23    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 1.)
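The extra positional argument is the new rope_percentage parameter added to RotaryPEMultiHeadAttention in this patch: 1. rotates all of each head's features, while the class default is now 0.5. A hedged sketch of the two settings (variable names are illustrative):

full_rope = RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 1.)   # rotary embeddings on every feature
half_rope = RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 0.5)  # on half the features (new default)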
@@ -157,7 +157,7 @@
-46    experiment.create(name="rotary_pe_transformer")
+46    experiment.create(name="rotary_pe_transformer", writers={'screen'})
diff --git a/docs/transformers/rope/index.html b/docs/transformers/rope/index.html index 06512a47..6fbaa978 100644 --- a/docs/transformers/rope/index.html +++ b/docs/transformers/rope/index.html @@ -90,19 +90,19 @@ #

RoPE module

Rotary encoding transforms pairs of features by rotating in the 2D plane. That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it by an angle depending on the position of the token.

For a pair of features

Let $x^{(1)}_m$ and $x^{(2)}_m$ be two features of the key or query of any head at position $m$. Or for simplicity assume $x$ has only two features. Then the transformation is,

$$\text{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m\big) =
\begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix}
\begin{pmatrix} x^{(1)}_m \\ x^{(2)}_m \end{pmatrix}$$

where $\theta$ is a constant angle. The other pairs of features are transformed similarly.

Attention is relative

For a pair of features, the dot-product attention score between two positions $m$ and $n$ would be

$$\Big\langle \text{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m\big), \text{RoPE}\big(x^{(1)}_n, x^{(2)}_n, n\big) \Big\rangle =
\Big\langle \text{RoPE}\big(x^{(1)}_m, x^{(2)}_m, m - n\big), \text{RoPE}\big(x^{(1)}_n, x^{(2)}_n, 0\big) \Big\rangle$$

This shows that for dot-product attention the rotary encodings give relative attention.

For all features

The features are grouped into pairs and handled as above. They use a different $\theta$ for each pair.

The paper suggests using $\Theta = \big\{\theta_i = 10000^{-\frac{2(i-1)}{d}},\ i \in [1, 2, \dots, \frac{d}{2}]\big\}$ for the $\frac{d}{2}$ pairs of features.

We pair feature $i$ with feature $i + \frac{d}{2}$. So for position $m$ we transform

$$\begin{pmatrix} x^{(i)}_m \\ x^{(i + \frac{d}{2})}_m \end{pmatrix}$$

to

$$\begin{pmatrix} x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\ x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i \end{pmatrix}$$
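A small numeric sketch of this rotation and its relative-position property (standalone PyTorch, not part of the patch; the angle and feature values are arbitrary):

import torch

def rope_pair(x1: float, x2: float, m: int, theta: float) -> torch.Tensor:
    # Rotate the feature pair (x1, x2) by the position-dependent angle m * theta
    angle = torch.tensor(m * theta)
    return torch.stack([x1 * torch.cos(angle) - x2 * torch.sin(angle),
                        x2 * torch.cos(angle) + x1 * torch.sin(angle)])

# The dot product of rotated query and key pairs depends only on m - n
q, k, theta = (1.0, 2.0), (3.0, 4.0), 0.1
score_a = rope_pair(*q, m=7, theta=theta) @ rope_pair(*k, m=5, theta=theta)
score_b = rope_pair(*q, m=2, theta=theta) @ rope_pair(*k, m=0, theta=theta)
assert torch.allclose(score_a, score_b)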
32class RotaryPositionalEmbeddings(nn.Module):
@@ -114,13 +114,13 @@ #
-
118    def __init__(self, d: int, base: int = 10_000):
+
119    def __init__(self, d: int, base: int = 10_000):
@@ -131,33 +131,37 @@
-
123        super().__init__()
+
124        super().__init__()
+125
+126        self.base = base
+127        self.d = d
+128        self.cos_cached = None
+129        self.sin_cached = None
-
+
-

+

Cache $\cos$ and $\sin$ values

-
125        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
+
131    def _build_cache(self, x: torch.Tensor):
-
+
-
  • x is the Tensor at the head of a key or a query with shape [seq_len, batch_size, n_heads, d]
+

Return if cache is already built

-
127    def forward(self, x: torch.Tensor):
+
136        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
+137            return
@@ -165,11 +169,11 @@ -

Extract the shape

+

Get sequence length

-
132        seq_len, batch_size, n_heads, d = x.shape
+
140        seq_len = x.shape[0]
@@ -177,11 +181,11 @@ -

+

-
135        d_2 = d // 2
+
143        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
@@ -194,7 +198,7 @@
-
138        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
+
146        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
@@ -202,11 +206,11 @@ -

Calculate the product of position index and $\theta_i$

+

Calculate the product of position index and $\theta_i$

-
141        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
+
149        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
@@ -214,11 +218,11 @@ -

Concatenate so that for row $m$ we have $[m\theta_0, m\theta_1, \dots, m\theta_{\frac{d}{2}-1}, m\theta_0, m\theta_1, \dots, m\theta_{\frac{d}{2}-1}]$

+

Concatenate so that for row $m$ we have $[m\theta_0, m\theta_1, \dots, m\theta_{\frac{d}{2}-1}, m\theta_0, m\theta_1, \dots, m\theta_{\frac{d}{2}-1}]$

-
145        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
+
153        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
@@ -226,11 +230,12 @@ -

Calculate $\big[-x^{(\frac{d}{2}+1)}, -x^{(\frac{d}{2}+2)}, \dots, -x^{(d)}, x^{(1)}, x^{(2)}, \dots, x^{(\frac{d}{2})}\big]$

+

Cache them

-
148        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
+
156        self.cos_cached = idx_theta2.cos()[:, None, None, :]
+157        self.sin_cached = idx_theta2.sin()[:, None, None, :]
@@ -238,12 +243,10 @@ -

Calculate $x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i$ and $x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i$

-

for $i \in \big\{1, 2, \dots, \frac{d}{2}\big\}$

- +
-
160        rx = (x * idx_theta2.cos()[:, None, None, :]) + (neg_half_x * idx_theta2.sin()[:, None, None, :])
+
159    def _neg_half(self, x: torch.Tensor):
@@ -251,35 +254,37 @@ -

+

-
163        return rx
+
161        d_2 = self.d // 2
-
+
-

Multi-head attention with rotary positional embeddings

-

We override multi-head attention from original transformer.

+

Calculate

-
166class RotaryPEMultiHeadAttention(MultiHeadAttention):
+
164        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
-
+
- +
  • x + is the Tensor at the head of a key or a query with shape [seq_len, batch_size, n_heads, d] +
+
-
173    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+
166    def forward(self, x: torch.Tensor):
@@ -287,12 +292,11 @@ -

The linear transformations do not need a bias since we explicitly include it when calculating scores. However having a bias for value - might make sense.

+

Cache and values

-
177        super().__init__(heads, d_model, dropout_prob, bias=False)
+
171        self._build_cache(x)
@@ -300,24 +304,23 @@ -

Rotary positional embedding layers

+

Split the features, we can choose to apply rotary embeddings only to a partial set of features.

-
180        self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
-181        self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
+
174        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
-
+
-

Calculate scores between queries and keys

+

Calculate

-
183    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+
178        neg_half_x = self._neg_half(x_rope)
@@ -325,43 +328,119 @@ +

Calculate

+

for

+ +
+
+
190        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
+
+
+
+
+ +

+ +
+
+
193        return torch.cat((x_rope, x_pass), dim=-1)
+
+
+
+
+ +

Multi-head attention with rotary positional embeddings

+

We override multi-head attention from original transformer.

+ +
+
+
196class RotaryPEMultiHeadAttention(MultiHeadAttention):
+
+
+
+
+ + +
+
+
203    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
+204        super().__init__(heads, d_model, dropout_prob)
+
+
+
+
+ +

Rotary positional embedding layers

+ +
+
+
207        d_rope = int(self.d_k * rope_percentage)
+208        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+209        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+
+
+
+
+ +

Calculate scores between queries and keys

+ +
+
+
211    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+
+
+
+
+

Calculate dot-product with RoPE

-
189        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
+
217        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
-
+

Testing RoPE with a simple example

-
192def _test_rotary():
+
220def _test_rotary():
-
+
-
196    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
-197    x = x[:, None, None, :]
-198    inspect(x)
-199
-200    rotary_pe = RotaryPositionalEmbeddings(3)
-201    inspect(rotary_pe(x))
-202
-203
-204if __name__ == '__main__':
-205    _test_rotary()
+
224    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
+225    x = x[:, None, None, :]
+226    inspect(x)
+227
+228    rotary_pe = RotaryPositionalEmbeddings(3)
+229    inspect(rotary_pe(x))
+230
+231
+232if __name__ == '__main__':
+233    _test_rotary()