This is based on code by Georges Harik (@gharik).
11import random
12import string
13from typing import List
14
15import torch
16from labml.logger import Text
17from torch.utils.data import DataLoader, Dataset
18
19from labml import monit, logger, tracker
20from labml.configs import option
21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
This creates arithmetic addition problems and their solutions with workings. Only addition is implemented so far.
It uses character-level tokenization.
class ArithmeticDataset(Dataset):
    """
    ## Arithmetic Dataset

    Generates arithmetic addition problems together with step-by-step workings
    and the final answer. Only addition is implemented. Text is tokenized at
    the character level using the vocabulary in `itos`.
    """

    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
        """
        :param seq_len: is the sequence length of generated math problems.
            We fill as many problems as possible up to this length
        :param max_digits: is the maximum number of digits in the operand integers
        :param n_sequences: is the number of sequences per epoch
        """
        self.n_sequences = n_sequences
        self.max_digits = max_digits
        self.seq_len = seq_len
        # Token id to string
        self.itos = list(string.digits + 'xe =\n?+;')
        # Character to token id
        self.stoi = {c: i for i, c in enumerate(self.itos)}

    @staticmethod
    def make_int(n_digits: int) -> int:
        """
        Generates a random integer with exactly `n_digits` decimal digits.

        The leading digit is drawn from 1-9 and every other digit from 0-9.
        `random.randrange` excludes its upper bound, so the bounds are 10
        (not 11); a drawn value of 10 would carry into the next place and
        could yield a number with more than `n_digits` digits.
        """
        res = 0
        for i in range(n_digits):
            d = random.randrange(1, 10) if i == 0 else random.randrange(0, 10)
            res = res * 10 + d

        return res

    @staticmethod
    def get_add_explanation(x: int, y: int) -> str:
        """
        Generates the workings for `x + y`.

        Each step adds the digits at one decimal place plus the incoming carry.
        For example, for `11+29` it generates
        `1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1`.
        Returns an empty string when both operands are zero.
        """
        carry = 0
        e = 0
        explanation = []
        while x > 0 or y > 0 or carry > 0:
            rx, ry = x % 10, y % 10
            total = rx + ry + carry
            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
            x, y, carry = x // 10, y // 10, total // 10
            e += 1

        return ' '.join(explanation)

    def make_add_problem(self) -> str:
        """
        Creates an arithmetic addition problem with workings and answer.

        The returned string has the form `x={x}+{y}; {workings} x=={x + y}`
        followed by a new line.
        """
        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))

        explanation = self.get_add_explanation(x, y)
        return f"x={x}+{y}; {explanation} x=={x + y}\n"

    def get_qa(self):
        """
        Get an arithmetic problem and its answer. This is used for evaluation.

        Returns a `(question, answer)` pair of strings; the question ends
        with `;` and the answer is the decimal sum.
        """
        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))

        return f'x={x}+{y};', f'{x + y}'

    def get_packed_math_input(self) -> List[int]:
        """
        Generate multiple problems and pack them into a single token sequence.

        Problems, each prefixed with `?`, are appended until the sequence has
        more than `seq_len` tokens, so that an input and a one-step-shifted
        target of length `seq_len` can both be sliced from it.
        """
        s_enc = []
        while len(s_enc) <= self.seq_len:
            s_part = self.make_add_problem()
            # Extend in place instead of rebuilding the list on every problem
            s_enc.extend(self.encode('?' + s_part))
        return s_enc

    def encode(self, s: str) -> List[int]:
        """
        Encode a given string into a list of token ids
        """
        return [self.stoi[c] for c in s]

    def decode(self, arr: List[int]) -> str:
        """
        Decode a list of token ids into a string
        """
        return ''.join([self.itos[c] for c in arr])

    def __getitem__(self, idx: int):
        """
        Get an input and target pair for auto-regressive modelling.

        `idx` is ignored; a fresh random sequence is generated on every call.
        The target is the input shifted by one token.
        """
        s = torch.tensor(self.get_packed_math_input())
        return s[:self.seq_len], s[1:self.seq_len + 1]

    def __len__(self):
        """
        Number of sequences per epoch
        """
        return self.n_sequences
class ArithmeticAutoregression(NLPAutoRegressionConfigs):
    """
    ## Arithmetic Task Experiment Configurations

    Extends the auto-regressive NLP configurations with the arithmetic
    training data loader and an evaluation routine that scores the model
    on freshly generated addition problems.
    """
    # Maximum number of digits per operand integer
    max_digits: int = 4
    # Number of training sequences per epoch
    train_sequences_per_epoch: int = 2 ** 12
    # Training data loader
    train_loader: DataLoader = 'arithmetic_train_loader'
    # Number of problems in evaluation
    n_tests: int = 64
    # No need of a validation dataset
    validator = None
    # Number of times to run evaluations per epoch
    inner_iterations = 4
    # Number of tokens in the vocabulary
    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)

    @torch.no_grad()
    def sample(self):
        """
        ### Evaluation

        Generates `n_tests` problems, greedily decodes the model's
        continuation of each question, and tracks the fraction of exactly
        correct answers as `score`. One sampled result is also logged.
        """

        # Skip in the first epoch
        if self.training_loop.idx < 1:
            return

        # Create a dataset to generate problems
        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
        # Get a set of problems and answers
        qa = [dataset.get_qa() for _ in range(self.n_tests)]
        # Collect the problems only
        questions = [p[0] for p in qa]

        # Create a tensor with only the initial token of each question;
        # the shape is `[1, n_tests]` and new tokens are appended along dim 0
        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
        # Move to device
        data = data.to(self.device)

        # Flags for the sequences that have completed
        finished = torch.zeros((len(questions),)).bool().to(self.device)
        # Token id of the new line character - this marks end of the answer
        new_line = dataset.stoi['\n']

        # Sampled results
        results = [p[0] for p in questions]

        # Sample up to sequence length
        for i in monit.iterate('Sample', self.seq_len - 1):
            # If all the sequences have completed we skip this
            if finished.all():
                continue

            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy) for the last position
            output = output[-1].argmax(dim=-1)

            # Override with the question while we are still inside the prompt.
            # This is done before the completion check so that a new line
            # predicted inside the prompt (and then discarded) cannot mark a
            # sequence as finished before any answer was generated.
            for j, p in enumerate(questions):
                if len(p) > i + 1:
                    output[j] = dataset.stoi[p[i + 1]]

            # Find which sequences have finished
            finished = finished | (output == new_line)
            # Skip if all have finished
            if finished.all():
                continue

            # Add the next token to the input
            data = torch.cat([data, output[None, :]], dim=0)

            # Get the sampled results
            for j, c in enumerate(output):
                results[j] += dataset.itos[c]

        # Discard everything after the answer in the results
        results = [r.split('\n')[0] for r in results]

        # Log a sample
        res_sample = results[0].split(';')
        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])

        # Get the answers
        results = [r.split('x==')[-1] for r in results]

        # Count the number of correct answers
        correct = 0
        for r, _qa in zip(results, qa):
            if r == _qa[1]:
                correct += 1

        # Log the score
        tracker.save('score', correct / len(results))
Training data loader
@option(ArithmeticAutoregression.train_loader)
def arithmetic_train_loader(c: ArithmeticAutoregression):
    """
    Training data loader

    Wraps an `ArithmeticDataset`, sized from the configuration `c`, in a
    `DataLoader` that collates batches with `transpose_batch`.
    """
    train_dataset = ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch)
    return DataLoader(
        train_dataset,
        batch_size=c.batch_size,
        collate_fn=transpose_batch,
        num_workers=4,
    )
Code to test generated problems
def _test():
    """
    Code to test generated problems

    Prints one decoded packed sequence for visual inspection.
    """
    generator = ArithmeticDataset(256, 8, 10)
    packed_tokens = generator.get_packed_math_input()

    print(generator.decode(packed_tokens))
# Run the problem-generation check only when executed as a script
if __name__ == '__main__':
    _test()