diff --git a/docs/experiments/arithmetic_dataset.html b/docs/experiments/arithmetic_dataset.html
index 1641927a..c35bbd9a 100644
--- a/docs/experiments/arithmetic_dataset.html
+++ b/docs/experiments/arithmetic_dataset.html
@@ -447,7 +447,7 @@
-    n_tests: int = 32
+    n_tests: int = 64
+    @torch.no_grad()
     def sample(self):

         if self.training_loop.idx < 1:
             return
         dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
         qa = [dataset.get_qa() for _ in range(self.n_tests)]
         questions = [p[0] for p in qa]

         data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
         data = data.to(self.device)

         finished = torch.zeros((len(questions),)).bool().to(self.device)
         new_line = dataset.stoi['\n']

         results = [p[0] for p in questions]

         for i in monit.iterate('Sample', self.seq_len - 1):
             if finished.sum() == len(finished):
                 continue

             output, *_ = self.model(data)
             output = output[-1].argmax(dim=-1)

             finished = finished | (output == new_line)
             if finished.sum() == len(finished):
                 continue

             for j, p in enumerate(questions):
                 if len(p) > i + 1:
                     output[j] = dataset.stoi[p[i + 1]]

             data = torch.cat([data, output[None, :]], dim=0)

             for j, c in enumerate(output):
                 results[j] += dataset.itos[c]

         results = [r.split('\n')[0] for r in results]

         res_sample = results[0].split(';')
         logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])

         results = [r.split('x==')[-1] for r in results]

         correct = 0
         for r, _qa in zip(results, qa):
             if r == _qa[1]:
                 correct += 1

         tracker.save('score', correct / len(results))
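The only functional change in this hunk is the @torch.no_grad() decorator on sample(), which switches off autograd bookkeeping for the whole evaluation loop. A minimal, self-contained sketch of the same pattern (the linear model and random tokens below are hypothetical stand-ins, not the repo's classes):

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 16)    # stand-in for the trained transformer
    tokens = torch.randn(8, 16)  # stand-in for the encoded prompt batch

    @torch.no_grad()             # autograd is disabled for the whole call
    def greedy_step(x: torch.Tensor) -> torch.Tensor:
        # outputs come back detached, so no intermediate activations are kept
        return model(x)

    logits = greedy_step(tokens)
    assert not logits.requires_grad  # would be True without the decorator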
 @option(ArithmeticAutoregression.train_loader)
 def arithmetic_train_loader(c: ArithmeticAutoregression):
     return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
                       batch_size=c.batch_size,
                       collate_fn=transpose_batch,
                       num_workers=4)
 def _test():
     dataset = ArithmeticDataset(256, 8, 10)

     print(dataset.decode(dataset.get_packed_math_input()))


 if __name__ == '__main__':
     _test()
     def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
-        super().__init__(heads, d_model, dropout_prob, bias=False)
45 experiment.create(name="roper_addition", comment="rotary value 8", writers={'screen', 'labml', 'comet'})
45 experiment.create(name="roper_addition", comment="rotary value 7", writers={'screen', 'labml', 'comet'})
 experiment.configs(conf, {
-    'max_digits': 8,

-    'optimizer.optimizer': 'Noam',
-    'optimizer.learning_rate': 1.,
+    'optimizer.optimizer': 'Adam',
+    'optimizer.learning_rate': 2.5e-4,
 })
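Both experiments drop the Noam schedule in favour of plain Adam with a fixed learning rate of 2.5e-4. With Noam, the configured 'learning_rate' of 1.0 is a multiplier on a warmup-then-inverse-square-root curve rather than an absolute step size, which is why the two numbers differ by orders of magnitude. A small sketch of the standard Noam schedule from "Attention Is All You Need" (the function name and the d_model/warmup values are illustrative, not the repo's optimizer code):

    def noam_lr(step: int, d_model: int = 128, warmup: int = 4000, base: float = 1.0) -> float:
        # lr = base * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
        step = max(step, 1)
        return base * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

    # Rises linearly until step == warmup, then decays as 1/sqrt(step);
    # the configured 'learning_rate' only scales this whole curve.
    print(noam_lr(1), noam_lr(4_000), noam_lr(40_000))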
 def _rotary_value_pe_mha(c: TransformerConfigs):
     from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
-    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 0.5)
-experiment.create(name="rotary_pe_transformer", comment="rotary_value 1.0, 0.5", writers={'screen', 'labml'})
+experiment.create(name="rotary_shakespeare", comment="rotary value", writers={'screen', 'labml'})
-    'seq_len': 128,
+    'seq_len': 512,

-    'epochs': 32,
+    'epochs': 24,

-    'batch_size': 4,
+    'batch_size': 16,

-    'inner_iterations': 10,
+    'inner_iterations': 4,

-    'd_model': 256,
-    'transformer.ffn.d_ff': 1024,
-    'transformer.n_heads': 8,
+    'd_model': 128,
+    'transformer.ffn.d_ff': 512,
+    'transformer.n_heads': 4,
     'transformer.dropout': 0.0,
-Use Noam optimizer
+Use Adam optimizer

-    'optimizer.optimizer': 'Noam',
-    'optimizer.learning_rate': 1.,
+    'optimizer.optimizer': 'Adam',
+    'optimizer.learning_rate': 2.5e-4,

     'dataloader_shuffle_with_replacement': True
 })
diff --git a/docs/transformers/rope/value_pe/index.html b/docs/transformers/rope/value_pe/index.html
index 30281ed9..fa4f0e44 100644
--- a/docs/transformers/rope/value_pe/index.html
+++ b/docs/transformers/rope/value_pe/index.html
@@ -97,8 +97,7 @@
 import torch

-from labml_nn.transformers.mha import MultiHeadAttention
-from labml_nn.transformers.rope import RotaryPositionalEmbeddings
+from labml_nn.transformers.rope import RotaryPositionalEmbeddings, RotaryPEMultiHeadAttention
 class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):

     def forward(self, x: torch.Tensor):

         self._build_cache(x)

         x_rope, x_pass = x[..., :self.d], x[..., self.d:]

         neg_half_x = self._neg_half(x_rope)

         x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])

         return torch.cat((x_rope, x_pass), dim=-1)
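The only difference from the parent class's forward is the sign of the sine term: standard RoPE computes x_rope * cos + neg_half_x * sin, while this reverse variant subtracts, i.e. it rotates position m by −mθ instead of mθ and therefore undoes the forward rotation. A self-contained round-trip check on a single feature pair (adjacent pairing, toy theta and position; independent of the repo classes):

    import math
    import torch

    def rotate(x: torch.Tensor, m: int, theta: float) -> torch.Tensor:
        # forward RoPE on one (x1, x2) feature pair: rotate by the angle m * theta
        x1, x2 = x[..., 0], x[..., 1]
        c, s = math.cos(m * theta), math.sin(m * theta)
        return torch.stack([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)

    def reverse_rotate(x: torch.Tensor, m: int, theta: float) -> torch.Tensor:
        # same rotation with the sine term negated, i.e. the angle -m * theta
        return rotate(x, -m, theta)

    pair = torch.tensor([1.0, 2.0])
    roundtrip = reverse_rotate(rotate(pair, m=7, theta=0.01), m=7, theta=0.01)
    assert torch.allclose(roundtrip, pair)  # the reverse undoes the forward rotation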
-class RotaryValuePEMultiHeadAttention(MultiHeadAttention):
+class RotaryValuePEMultiHeadAttention(RotaryPEMultiHeadAttention):
     def __init__(self, heads: int, d_model: int,
                  rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
                  dropout_prob: float = 0.0):
-        super().__init__(heads, d_model, dropout_prob, bias=False)
+        super().__init__(heads, d_model, rope_percentage, dropout_prob)

-        d_rope = int(self.d_k * rope_percentage)
         d_rope_value = int(self.d_k * rope_value_percentage)

-        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
-        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
         self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
         self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
-    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

Calculate dot-product with RoPE

-        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
query, key and value
@@ -278,17 +250,17 @@
     def forward(self, *,
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
         seq_len, batch_size, _ = query.shape

         if mask is not None:
             mask = self.prepare_mask(mask, query.shape, key.shape)

         query = self.query(query)
         key = self.key(key)
         value = self.value(value)
Compute attention scores. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].
         scores = self.get_scores(query, key)
         if mask is not None:
             scores = scores.masked_fill(mask == 0, float('-inf'))
Softmax attention along the key sequence dimension
         attn = self.softmax(scores)

Apply dropout
         attn = self.dropout(attn)

Rotate value embeddings before taking the weighted sum so that they contain positional information
         value = self.value_rotary_pe(value)

Multiply by values
         x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))

Rotate in the opposite direction so that each embedding holds the relative positions
         x = self.value_reverse_rotary_pe(x)

Save attentions for any other calculations
         self.attn = attn.detach()

Concatenate multiple heads
         x = x.reshape(seq_len, batch_size, -1)

Output layer
         return self.output(x)
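Putting the two value rotations together: each value at position j is rotated by jθ before the attention-weighted sum, and the summed output for query position i is then rotated back by −iθ. Since the rotations are linear and compose additively, value j's contribution to output i ends up rotated by the relative offset (j − i)θ, which is what the comments above describe. A self-contained numerical check of that identity on one 2-dimensional head (positions, weights and the rot helper are illustrative, not the repo's code):

    import math
    import torch

    def rot(angle: float) -> torch.Tensor:
        # 2x2 rotation matrix R(angle)
        return torch.tensor([[math.cos(angle), -math.sin(angle)],
                             [math.sin(angle),  math.cos(angle)]])

    theta = 0.1
    i = 3                                    # query position
    attn = torch.tensor([0.2, 0.5, 0.3])     # attention weights over value positions j = 0, 1, 2
    values = torch.randn(3, 2)               # one 2-d value vector per position

    # forward path: rotate each value by j*theta, take the weighted sum, reverse-rotate by i*theta
    rotated = torch.stack([rot(j * theta) @ values[j] for j in range(3)])
    out = rot(-i * theta) @ (attn[:, None] * rotated).sum(dim=0)

    # equivalent form: each value rotated by the *relative* offset (j - i)*theta
    relative = torch.stack([rot((j - i) * theta) @ values[j] for j in range(3)])
    expected = (attn[:, None] * relative).sum(dim=0)

    assert torch.allclose(out, expected, atol=1e-6)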