"""
# Arithmetic Dataset

*This is based on code by Georges Harik (@gharik).*

This creates arithmetic addition problems and solutions with workings,
using a character level tokenization. We've only implemented addition so far.
"""

import random
import string
from typing import List

import torch
from labml.logger import Text
from torch.utils.data import DataLoader, Dataset

from labml import monit, logger, tracker
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch


class ArithmeticDataset(Dataset):
    """
    ## Arithmetic Dataset

    This creates arithmetic addition problems and solutions with workings.
    We've only implemented addition so far.

    It's based on a character level tokenization.
    """

    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
        """
        :param seq_len: is the sequence length of generated math problems.
            We fill as many problems as possible up to this length
        :param max_digits: is the maximum number of digits in the operand integers
        :param n_sequences: is the number of sequences per epoch
        """
        self.n_sequences = n_sequences
        self.max_digits = max_digits
        self.seq_len = seq_len
        # Token id to string
        self.itos = list(string.digits + 'xe =\n?+;')
        # Character to token id
        self.stoi = {c: i for i, c in enumerate(self.itos)}

    @staticmethod
    def make_int(n_digits: int) -> int:
        """
        Generates an integer with exactly `n_digits` decimal digits
        (no leading zero; returns `0` when `n_digits` is `0`).
        """
        res = 0
        for i in range(n_digits):
            # `randrange` excludes its upper bound: the leading digit is 1-9 and
            # the others are 0-9. Allowing 10 here would carry into the next
            # place and could produce a number with too many digits.
            d = random.randrange(1, 10) if i == 0 else random.randrange(0, 10)
            res = res * 10 + d

        return res

    @staticmethod
    def get_add_explanation(x: int, y: int) -> str:
        """
        Generates the workings for `x + y`: one space-separated term per decimal
        place, least significant place first. Each term has the form
        `{x digit}e{place}+{y digit}e{place}+{carry}e{place}=={sum}e{place}`.

        For example for `11+29` it generates `1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1`.
        """
        carry = 0
        e = 0
        explanation = []
        while x > 0 or y > 0 or carry > 0:
            rx, ry = x % 10, y % 10
            total = rx + ry + carry
            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
            x, y, carry = x // 10, y // 10, total // 10
            e += 1

        return ' '.join(explanation)

    def make_add_problem(self) -> str:
        """
        Creates an arithmetic addition problem with workings and answer, formatted
        as `x={x}+{y}; {workings} x=={x + y}` followed by a new line.
        Each operand has between 1 and `max_digits` digits.
        """
        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))

        explanation = self.get_add_explanation(x, y)
        return f"x={x}+{y}; {explanation} x=={x + y}\n"

    def get_qa(self):
        """
        Get an arithmetic problem and its answer. This is used for evaluation.

        :return: a `(question, answer)` pair of strings, e.g. `('x=11+29;', '40')`
        """
        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))

        return f'x={x}+{y};', f'{x + y}'

    def get_packed_math_input(self) -> List[int]:
        """
        Generate multiple problems and pack them into a sequence of token ids.

        Each problem is prefixed with `?`. Problems are added until the sequence is
        strictly longer than `seq_len`, so `__getitem__` can take an input of
        `seq_len` tokens and a target shifted by one.
        """
        s_enc = []
        while len(s_enc) <= self.seq_len:
            s_part = self.make_add_problem()
            # Extend in place instead of re-copying the whole list per problem
            s_enc.extend(self.encode('?' + s_part))
        return s_enc

    def encode(self, s: str) -> List[int]:
        """
        Encode a given string into token ids.
        Raises `KeyError` for characters outside the vocabulary.
        """
        return [self.stoi[c] for c in s]

    def decode(self, arr: List[int]) -> str:
        """
        Decode a list of token ids back into a string.
        """
        return ''.join([self.itos[c] for c in arr])

    def __getitem__(self, idx: int):
        """
        Get an input and target pair for auto-regressive modelling.

        `idx` is not used: every call generates a fresh random sequence.
        """
        s = torch.tensor(self.get_packed_math_input())
        return s[:self.seq_len], s[1:self.seq_len + 1]

    def __len__(self):
        """
        Number of sequences per epoch
        """
        return self.n_sequences


class ArithmeticAutoregression(NLPAutoRegressionConfigs):
    """
    Experiment configurations for training an auto-regressive model on
    generated arithmetic addition problems.
    """
    # Maximum number of digits per operand integer
    max_digits: int = 4
    # Number of training sequences per epoch
    train_sequences_per_epoch: int = 2 ** 12
    # Training data loader
    train_loader: DataLoader = 'arithmetic_train_loader'
    # Number of problems in evaluation
    n_tests: int = 64
    # No need of a validation dataset
    validator = None
    # Number of times to run evaluations per epoch
    inner_iterations = 4
    # Number of tokens in the vocabulary
    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)

    @torch.no_grad()
    def sample(self):
        """
        ### Evaluation

        Greedily samples answers to `n_tests` freshly generated problems, logs the
        first sampled sequence, and saves the fraction of exactly-correct answers
        to the tracker as `score`. Does nothing during the first epoch.
        """
        # Skip in the first epoch, and when there is nothing to evaluate
        if self.training_loop.idx < 1 or self.n_tests < 1:
            return

        # Create a dataset to generate problems
        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
        # Get a set of problems and answers
        qa = [dataset.get_qa() for _ in range(self.n_tests)]
        # Collect the problems only
        questions = [p[0] for p in qa]

        # Create a tensor with only the initial token of each question.
        # Shape is `[1, len(questions)]`; new time steps are appended along dim 0.
        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
        # Move to device
        data = data.to(self.device)

        # Which sequences have completed
        finished = torch.zeros(len(questions), dtype=torch.bool, device=self.device)
        # Token id of the new line character - this marks end of the answer
        new_line = dataset.stoi['\n']

        # Sampled results
        results = [p[0] for p in questions]

        # Sample up to sequence length
        for i in monit.iterate('Sample', self.seq_len - 1):
            # If all the sequences have completed we skip this
            if finished.all():
                continue

            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy) at the last position.
            # NOTE(review): this indexing assumes the output is `[seq, batch, vocab]`.
            output = output[-1].argmax(dim=-1)

            # Override with the question while still inside the prompt.
            # This must happen *before* the new line check: a new line predicted
            # inside the prompt is discarded, so it must not mark the sequence finished.
            for j, p in enumerate(questions):
                if len(p) > i + 1:
                    output[j] = dataset.stoi[p[i + 1]]

            # Find which sequences have finished
            finished = finished | (output == new_line)
            # Skip if all have finished
            if finished.all():
                continue

            # Add the next token to the input
            data = torch.cat([data, output[None, :]], dim=0)

            # Append the sampled characters to the results
            for j, c in enumerate(output.tolist()):
                results[j] += dataset.itos[c]

        # Discard everything after the answer in the results
        results = [r.split('\n')[0] for r in results]

        # Log a sample
        res_sample = results[0].split(';')
        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])

        # Get the answers (text after the last `x==`)
        results = [r.split('x==')[-1] for r in results]

        # Count the number of correct answers
        correct = sum(1 for r, (_, answer) in zip(results, qa) if r == answer)

        # Log the score
        tracker.save('score', correct / len(results))


# Training data loader
@option(ArithmeticAutoregression.train_loader)
def arithmetic_train_loader(c: ArithmeticAutoregression):
    """
    Training data loader.

    Wraps an `ArithmeticDataset` sized by the experiment configs and collates
    batches with `transpose_batch`, using 4 worker processes.
    """
    train_dataset = ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch)
    return DataLoader(
        dataset=train_dataset,
        batch_size=c.batch_size,
        collate_fn=transpose_batch,
        num_workers=4,
    )


# Code to test generated problems
246                      num_workers=4)Code to test generated problems
249def _test():253    dataset = ArithmeticDataset(256, 8, 10)
254
255    print(dataset.decode(dataset.get_packed_math_input()))259if __name__ == '__main__':
260    _test()