diff --git a/docs/experiments/arithmetic_dataset.html b/docs/experiments/arithmetic_dataset.html
index 1641927a..c35bbd9a 100644
--- a/docs/experiments/arithmetic_dataset.html
+++ b/docs/experiments/arithmetic_dataset.html
@@ -447,7 +447,7 @@
-149    n_tests: int = 32
+149    n_tests: int = 64
@@ -496,7 +496,8 @@
-157    def sample(self):
+157    @torch.no_grad()
+158    def sample(self):
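The added @torch.no_grad() decorator turns the whole sampling pass into an inference-only call, so no autograd graph is built while the model is queried repeatedly. A minimal sketch of the same pattern, with an illustrative model and method rather than the repository's exact class:

import torch
import torch.nn as nn

class Sampler:
    def __init__(self, model: nn.Module):
        self.model = model

    @torch.no_grad()  # gradients are not tracked anywhere inside this method
    def sample(self, prompt: torch.Tensor) -> torch.Tensor:
        logits = self.model(prompt)        # forward pass only
        return logits[-1].argmax(dim=-1)   # greedy choice for the next token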
@@ -508,8 +509,8 @@
-165        if self.training_loop.idx < 1:
-166            return
+166        if self.training_loop.idx < 1:
+167            return
@@ -521,7 +522,7 @@
-169        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
+170        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
@@ -533,7 +534,7 @@
-171        qa = [dataset.get_qa() for _ in range(self.n_tests)]
+172        qa = [dataset.get_qa() for _ in range(self.n_tests)]
@@ -545,7 +546,7 @@
-173        questions = [p[0] for p in qa]
+174        questions = [p[0] for p in qa]
@@ -557,7 +558,7 @@
-176        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
+177        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
@@ -569,7 +570,7 @@
-178        data = data.to(self.device)
+179        data = data.to(self.device)
@@ -581,7 +582,7 @@
-181        finished = torch.zeros((len(questions),)).bool().to(self.device)
+182        finished = torch.zeros((len(questions),)).bool().to(self.device)
@@ -593,7 +594,7 @@
-183        new_line = dataset.stoi['\n']
+184        new_line = dataset.stoi['\n']
@@ -605,7 +606,7 @@
-186        results = [p[0] for p in questions]
+187        results = [p[0] for p in questions]
@@ -617,7 +618,7 @@
-189        for i in monit.iterate('Sample', self.seq_len - 1):
+190        for i in monit.iterate('Sample', self.seq_len - 1):
@@ -629,8 +630,8 @@
-191            if finished.sum() == len(finished):
-192                continue
+192            if finished.sum() == len(finished):
+193                continue
@@ -642,7 +643,7 @@
-195            output, *_ = self.model(data)
+196            output, *_ = self.model(data)
@@ -654,7 +655,7 @@
-197            output = output[-1].argmax(dim=-1)
+198            output = output[-1].argmax(dim=-1)
@@ -666,7 +667,7 @@
-200            finished = finished | (output == new_line)
+201            finished = finished | (output == new_line)
@@ -678,8 +679,8 @@
-202            if finished.sum() == len(finished):
-203                continue
+203            if finished.sum() == len(finished):
+204                continue
@@ -691,9 +692,9 @@
-206            for j, p in enumerate(questions):
-207                if len(p) > i + 1:
-208                    output[j] = dataset.stoi[p[i + 1]]
+207            for j, p in enumerate(questions):
+208                if len(p) > i + 1:
+209                    output[j] = dataset.stoi[p[i + 1]]
@@ -705,7 +706,7 @@
-211            data = torch.cat([data, output[None, :]], dim=0)
+212            data = torch.cat([data, output[None, :]], dim=0)
@@ -717,8 +718,8 @@
-214            for j, c in enumerate(output):
-215                results[j] += dataset.itos[c]
+215            for j, c in enumerate(output):
+216                results[j] += dataset.itos[c]
@@ -730,7 +731,7 @@
-218        results = [r.split('\n')[0] for r in results]
+219        results = [r.split('\n')[0] for r in results]
@@ -742,8 +743,8 @@
-221        res_sample = results[0].split(';')
-222        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
+222        res_sample = results[0].split(';')
+223        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
@@ -755,7 +756,7 @@
-225        results = [r.split('x==')[-1] for r in results]
+226        results = [r.split('x==')[-1] for r in results]
@@ -767,10 +768,10 @@
-228        correct = 0
-229        for r, _qa in zip(results, qa):
-230            if r == _qa[1]:
-231                correct += 1
+229        correct = 0
+230        for r, _qa in zip(results, qa):
+231            if r == _qa[1]:
+232                correct += 1
@@ -782,7 +783,7 @@
-234        tracker.save('score', correct / len(results))
+235        tracker.save('score', correct / len(results))
@@ -794,8 +795,8 @@
-237@option(ArithmeticAutoregression.train_loader)
-238def arithmetic_train_loader(c: ArithmeticAutoregression):
+238@option(ArithmeticAutoregression.train_loader)
+239def arithmetic_train_loader(c: ArithmeticAutoregression):
@@ -806,10 +807,10 @@
-242    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
-243                      batch_size=c.batch_size,
-244                      collate_fn=transpose_batch,
-245                      num_workers=4)
+243    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
+244                      batch_size=c.batch_size,
+245                      collate_fn=transpose_batch,
+246                      num_workers=4)
@@ -821,7 +822,7 @@
-248def _test():
+249def _test():
@@ -832,9 +833,9 @@
-252    dataset = ArithmeticDataset(256, 8, 10)
-253
-254    print(dataset.decode(dataset.get_packed_math_input()))
+253    dataset = ArithmeticDataset(256, 8, 10)
+254
+255    print(dataset.decode(dataset.get_packed_math_input()))
@@ -846,8 +847,8 @@
-258if __name__ == '__main__':
-259    _test()
+259if __name__ == '__main__':
+260    _test()
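Read together, the renumbered hunks above document the evaluation sampler: it greedily decodes n_tests arithmetic prompts in parallel, teacher-forces the remaining characters of each question, and stops a row at its first newline before the answers are scored. A condensed, self-contained sketch of that loop, assuming the model returns logits of shape [seq_len, batch_size, vocab]; names follow the diff, the helper itself is illustrative:

import torch

def greedy_sample(model, dataset, questions, seq_len, device):
    # start every row with the first character of its question
    data = torch.tensor([[dataset.stoi[q[0]] for q in questions]], device=device)
    results = [q[0] for q in questions]
    finished = torch.zeros(len(questions), dtype=torch.bool, device=device)
    new_line = dataset.stoi['\n']

    for i in range(seq_len - 1):
        if finished.all():
            break
        output, *_ = model(data)              # logits for every position
        output = output[-1].argmax(dim=-1)    # greedy pick at the last position
        finished = finished | (output == new_line)
        # teacher-force characters that are still part of the question
        for j, q in enumerate(questions):
            if len(q) > i + 1:
                output[j] = dataset.stoi[q[i + 1]]
        data = torch.cat([data, output[None, :]], dim=0)
        for j, c in enumerate(output):
            results[j] += dataset.itos[c]
    # keep only the text up to the first newline of each row
    return [r.split('\n')[0] for r in results]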
 203    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
-204        super().__init__(heads, d_model, dropout_prob, bias=False)
+204        super().__init__(heads, d_model, dropout_prob)
diff --git a/docs/transformers/rope/value_pe/arithmetic_experiment.html b/docs/transformers/rope/value_pe/arithmetic_experiment.html
index 45d9d634..bd52b4dc 100644
--- a/docs/transformers/rope/value_pe/arithmetic_experiment.html
+++ b/docs/transformers/rope/value_pe/arithmetic_experiment.html
@@ -163,7 +163,7 @@
-45    experiment.create(name="roper_addition", comment="rotary value 8", writers={'screen', 'labml', 'comet'})
+45    experiment.create(name="roper_addition", comment="rotary value 7", writers={'screen', 'labml', 'comet'})
@@ -188,7 +188,7 @@
 49    experiment.configs(conf, {
-50        'max_digits': 8,
+50        'max_digits': 7,
@@ -296,12 +296,12 @@
-Use Noam optimizer
+Use Adam optimizer
-78        'optimizer.optimizer': 'Noam',
-79        'optimizer.learning_rate': 1.,
+78        'optimizer.optimizer': 'Adam',
+79        'optimizer.learning_rate': 2.5e-4,
 80    })
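Taken together, these hunks shrink the addition task to 7-digit operands and replace the Noam schedule with plain Adam at a fixed learning rate. Assembled from the lines shown in the diff, the updated calls would read roughly as below; conf and the untouched overrides come from the surrounding file:

from labml import experiment

# conf: the experiment's Configs object defined earlier in this file
experiment.create(name="roper_addition", comment="rotary value 7",
                  writers={'screen', 'labml', 'comet'})
experiment.configs(conf, {
    'max_digits': 7,                    # was 8
    # ... unchanged overrides elided ...
    'optimizer.optimizer': 'Adam',      # was 'Noam'
    'optimizer.learning_rate': 2.5e-4,  # was 1. under the Noam schedule
})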
diff --git a/docs/transformers/rope/value_pe/experiment.html b/docs/transformers/rope/value_pe/experiment.html
index 76690ddb..109a7605 100644
--- a/docs/transformers/rope/value_pe/experiment.html
+++ b/docs/transformers/rope/value_pe/experiment.html
@@ -116,7 +116,7 @@
 26def _rotary_value_pe_mha(c: TransformerConfigs):
 27    from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
-28    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 0.5)
+28    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.)
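The two trailing positional arguments are rope_percentage and rope_value_percentage (see the constructor further down in this diff): the fraction of each head's features that is rotated for queries/keys and for values. The change rotates all of the value features instead of half. A restatement of the factory with keyword arguments spelled out, purely for readability:

def _rotary_value_pe_mha(c):
    # c is the TransformerConfigs object, as in the file
    from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
    # rotate 100% of query/key features and 100% of value features
    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model,
                                           rope_percentage=1.,
                                           rope_value_percentage=1.)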
@@ -153,7 +153,7 @@
-39    experiment.create(name="rotary_pe_transformer", comment="rotary_value 1.0, 0.5", writers={'screen', 'labml'})
+39    experiment.create(name="rotary_shakespeare", comment="rotary value", writers={'screen', 'labml'})
@@ -286,7 +286,7 @@
-65        'seq_len': 128,
+65        'seq_len': 512,
@@ -298,7 +298,7 @@
-67        'epochs': 32,
+67        'epochs': 24,
@@ -310,7 +310,7 @@
-69        'batch_size': 4,
+69        'batch_size': 16,
@@ -322,7 +322,7 @@
-72        'inner_iterations': 10,
+72        'inner_iterations': 4,
@@ -334,9 +334,9 @@
-75        'd_model': 256,
-76        'transformer.ffn.d_ff': 1024,
-77        'transformer.n_heads': 8,
+75        'd_model': 128,
+76        'transformer.ffn.d_ff': 512,
+77        'transformer.n_heads': 4,
 78        'transformer.dropout': 0.0,
@@ -345,12 +345,12 @@
-Use Noam optimizer
+Use Adam optimizer
-81        'optimizer.optimizer': 'Noam',
-82        'optimizer.learning_rate': 1.,
+81        'optimizer.optimizer': 'Adam',
+82        'optimizer.learning_rate': 2.5e-4,
 83
 84        'dataloader_shuffle_with_replacement': True
 85    })
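Collecting the scattered hunks for this file, the Shakespeare run moves to a smaller model trained on longer sequences with larger batches, and likewise switches from the Noam schedule to Adam. A sketch of the resulting overrides, listing only the keys this diff touches; experiment and conf are defined earlier in the file:

from labml import experiment

experiment.create(name="rotary_shakespeare", comment="rotary value",
                  writers={'screen', 'labml'})
experiment.configs(conf, {
    'seq_len': 512,         # was 128
    'epochs': 24,           # was 32
    'batch_size': 16,       # was 4
    'inner_iterations': 4,  # was 10

    'd_model': 128,               # was 256
    'transformer.ffn.d_ff': 512,  # was 1024
    'transformer.n_heads': 4,     # was 8
    'transformer.dropout': 0.0,

    'optimizer.optimizer': 'Adam',      # was 'Noam'
    'optimizer.learning_rate': 2.5e-4,  # was 1.

    'dataloader_shuffle_with_replacement': True,
})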
diff --git a/docs/transformers/rope/value_pe/index.html b/docs/transformers/rope/value_pe/index.html
index 30281ed9..fa4f0e44 100644
--- a/docs/transformers/rope/value_pe/index.html
+++ b/docs/transformers/rope/value_pe/index.html
@@ -97,8 +97,7 @@
 119
 120import torch
 121
-122from labml_nn.transformers.mha import MultiHeadAttention
-123from labml_nn.transformers.rope import RotaryPositionalEmbeddings
+122from labml_nn.transformers.rope import RotaryPositionalEmbeddings, RotaryPEMultiHeadAttention
@@ -111,7 +110,7 @@
-126class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
+125class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
@@ -125,7 +124,7 @@
-133    def forward(self, x: torch.Tensor):
+132    def forward(self, x: torch.Tensor):
@@ -137,7 +136,7 @@
-138        self._build_cache(x)
+137        self._build_cache(x)
@@ -149,7 +148,7 @@
-141        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
+140        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
@@ -161,7 +160,7 @@
-145        neg_half_x = self._neg_half(x_rope)
+144        neg_half_x = self._neg_half(x_rope)
@@ -174,7 +173,7 @@
-161        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
+160        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
@@ -186,7 +185,7 @@
-164        return torch.cat((x_rope, x_pass), dim=-1)
+163        return torch.cat((x_rope, x_pass), dim=-1)
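These hunks only renumber ReverseRotaryPositionalEmbeddings, but the key line is worth spelling out. Assuming the parent class follows the usual RoPE formulation (cosine term plus sine term), the reverse embedding flips the sign of the sine term, which corresponds to rotating each feature pair by the negative of the position angle:

% standard RoPE (RotaryPositionalEmbeddings): rotation by +m\theta
\hat{x}_{\mathrm{rope}} = x_{\mathrm{rope}} \odot \cos(m\theta) + \mathrm{neg\_half}(x_{\mathrm{rope}}) \odot \sin(m\theta)
% reverse RoPE (this class): rotation by -m\theta
\hat{x}_{\mathrm{rope}} = x_{\mathrm{rope}} \odot \cos(m\theta) - \mathrm{neg\_half}(x_{\mathrm{rope}}) \odot \sin(m\theta)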
@@ -199,7 +198,7 @@
-167class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
+166class RotaryValuePEMultiHeadAttention(RotaryPEMultiHeadAttention):
@@ -210,10 +209,10 @@
-174    def __init__(self, heads: int, d_model: int,
-175                 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
-176                 dropout_prob: float = 0.0):
-177        super().__init__(heads, d_model, dropout_prob, bias=False)
+173    def __init__(self, heads: int, d_model: int,
+174                 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
+175                 dropout_prob: float = 0.0):
+176        super().__init__(heads, d_model, rope_percentage, dropout_prob)
@@ -225,13 +224,10 @@
-180        d_rope = int(self.d_k * rope_percentage)
-181        d_rope_value = int(self.d_k * rope_value_percentage)
-182
-183        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-184        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-185        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
-186        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
+179        d_rope_value = int(self.d_k * rope_value_percentage)
+180
+181        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
+182        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
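After this change the class inherits the query/key rotary embeddings (and get_scores) from RotaryPEMultiHeadAttention, and only adds the value-side pair itself. Reassembled from the added lines above, the constructor would now read:

from labml_nn.transformers.rope import RotaryPositionalEmbeddings, RotaryPEMultiHeadAttention
# ReverseRotaryPositionalEmbeddings is defined earlier in the same module

class RotaryValuePEMultiHeadAttention(RotaryPEMultiHeadAttention):
    def __init__(self, heads: int, d_model: int,
                 rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
                 dropout_prob: float = 0.0):
        # query/key rotary embeddings are created by the parent class
        super().__init__(heads, d_model, rope_percentage, dropout_prob)

        # rotary embeddings applied to the values, plus the inverse rotation
        d_rope_value = int(self.d_k * rope_value_percentage)
        self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
        self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)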
@@ -239,30 +235,6 @@
-Calculate scores between queries and keys
-188    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
-Calculate dot-product with RoPE
-194        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
 query, key and value
@@ -278,17 +250,17 @@
-196    def forward(self, *,
-197                query: torch.Tensor,
-198                key: torch.Tensor,
-199                value: torch.Tensor,
-200                mask: Optional[torch.Tensor] = None):
+184    def forward(self, *,
+185                query: torch.Tensor,
+186                key: torch.Tensor,
+187                value: torch.Tensor,
+188                mask: Optional[torch.Tensor] = None):
 query, key
@@ -298,16 +270,16 @@
-212        seq_len, batch_size, _ = query.shape
-213
-214        if mask is not None:
-215            mask = self.prepare_mask(mask, query.shape, key.shape)
+200        seq_len, batch_size, _ = query.shape
+201
+202        if mask is not None:
+203            mask = self.prepare_mask(mask, query.shape, key.shape)
 Prepare query, key
@@ -317,28 +289,28 @@
-219        query = self.query(query)
-220        key = self.key(key)
-221        value = self.value(value)
+207        query = self.query(query)
+208        key = self.key(key)
+209        value = self.value(value)
 Compute attention scores. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].
-225        scores = self.get_scores(query, key)
+213        scores = self.get_scores(query, key)
 Scale scores
-228        scores *= self.scale
+216        scores *= self.scale
 Apply mask
-231        if mask is not None:
-232            scores = scores.masked_fill(mask == 0, float('-inf'))
+219        if mask is not None:
+220            scores = scores.masked_fill(mask == 0, float('-inf'))
 Softmax attention along the key sequence dimension
-236        attn = self.softmax(scores)
+224        attn = self.softmax(scores)
+Apply dropout
+227        attn = self.dropout(attn)
+Rotate value embeddings before taking the weighted sum so that they contain positional information
+230        value = self.value_rotary_pe(value)
@@ -399,30 +395,6 @@
-239        attn = self.dropout(attn)
-Rotate value embeddings before taking the weighted sum so that they contain positional information
-242        value = self.value_rotary_pe(value)
 Multiply by values
-246        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
+234        x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
+Rotate in the opposite direction so that each embedding holds the relative positions
+237        x = self.value_reverse_rotary_pe(x)
+Save attentions for any other calculations
+240        self.attn = attn.detach()
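The two added comments carry the main idea of rotary value embeddings, and the algebra is worth making explicit. Writing R_m for the RoPE rotation at position m (so that R_m R_n = R_{m+n} and R_{-m} is the inverse of R_m), the values are rotated by their own positions before the attention-weighted sum, and the output at position i is then rotated back by -i:

o_i = R_{-i} \sum_j a_{ij} R_j v_j = \sum_j a_{ij} R_{j-i} v_j

so each output embedding depends on the value vectors only through the relative offsets j - i, mirroring what RoPE already achieves for the query-key dot products.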
@@ -446,11 +442,11 @@
-249        x = self.value_reverse_rotary_pe(x)
+243        x = x.reshape(seq_len, batch_size, -1)
@@ -458,35 +454,11 @@
-252        self.attn = attn.detach()
-Concatenate multiple heads
-255        x = x.reshape(seq_len, batch_size, -1)
 Output layer
-258        return self.output(x)
+246        return self.output(x)