This is based on code by Georges Harik (@gharik).

11import random
12import string
13from typing import List
14
15import torch
16from labml.logger import Text
17from torch.utils.data import DataLoader, Dataset
18
19from labml import monit, logger, tracker
20from labml.configs import option
21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch

Arithmetic Dataset

This creates arithmetic addition problems and solutions with workings. We've only implemented addition so far.

It uses character-level tokenization: every character of a problem is a separate token.

class ArithmeticDataset(Dataset):
    """
    ## Arithmetic Dataset

    Generates random addition problems together with digit-by-digit workings and
    the final answer, encoded with a character-level vocabulary.

    A single problem looks like::

        x=11+29; 1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1 x==40

    Training sequences pack several such problems back to back, each prefixed
    with `?` and terminated by a new line.
    """

    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
        """
        * `seq_len` is the sequence length of generated math problems. Problems
          are packed back to back until the sequence is longer than `seq_len`,
          so the last problem in a sequence is usually truncated.
        * `max_digits` is the maximum number of digits in the operand integers.
        * `n_sequences` is the number of sequences per epoch.
        """
        self.n_sequences = n_sequences
        self.max_digits = max_digits
        self.seq_len = seq_len
        # Token id to string
        self.itos = list(string.digits + 'xe =\n?+;')
        # Character to token id
        self.stoi = {c: i for i, c in enumerate(self.itos)}

    @staticmethod
    def make_int(n_digits: int):
        """
        Generates a random integer with exactly `n_digits` decimal digits.
        The leading digit is never zero. Returns `0` when `n_digits` is `0`.
        """
        res = 0
        for i in range(n_digits):
            # `randrange` excludes its upper bound, so `10` keeps every value a
            # single decimal digit. The leading digit starts at 1 so the number
            # really has `n_digits` digits.
            d = random.randrange(1, 10) if i == 0 else random.randrange(0, 10)
            res = res * 10 + d

        return res

    @staticmethod
    def get_add_explanation(x: int, y: int):
        """
        Generates the workings for `x + y`, least-significant digit first.

        Each step has the form `{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}`,
        where `e` is the decimal position. Steps are separated by spaces.
        For example, for `11+29` it generates
        `1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1`.

        Returns an empty string when both `x` and `y` are `0`.
        """
        carry = 0
        e = 0
        explanation = []
        while x > 0 or y > 0 or carry > 0:
            rx, ry = x % 10, y % 10
            total = rx + ry + carry
            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
            # Carry is the tens digit of this position's total
            x, y, carry = x // 10, y // 10, total // 10
            e += 1

        return ' '.join(explanation)

    def _make_operands(self):
        """Draws two random operands, each with 1 to `max_digits` digits."""
        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
        return x, y

    def make_add_problem(self):
        """
        Creates an arithmetic addition problem with workings and answer,
        for example `x=11+29; 1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1 x==40\\n`.
        """
        x, y = self._make_operands()
        explanation = self.get_add_explanation(x, y)
        return f"x={x}+{y}; {explanation} x=={x + y}\n"

    def get_qa(self):
        """
        Get an arithmetic problem and its answer. This is used for evaluation.

        Returns `(question, answer)`, for example `('x=11+29;', '40')`.
        """
        x, y = self._make_operands()
        return f'x={x}+{y};', f'{x + y}'

    def get_packed_math_input(self):
        """
        Generate multiple problems and pack them into one token-id list.

        Each problem is prefixed with `?`. Problems are appended until the list
        is strictly longer than `seq_len`, so that both the input and the
        one-step-shifted target of length `seq_len` can be sliced from it.
        """
        s_enc: List[int] = []
        while len(s_enc) <= self.seq_len:
            s_enc.extend(self.encode('?' + self.make_add_problem()))
        return s_enc

    def encode(self, s: str):
        """
        Encode a given string into token ids.
        Raises `KeyError` for characters outside the vocabulary.
        """
        return [self.stoi[c] for c in s]

    def decode(self, arr: List[int]):
        """Decode a list of token ids back into a string."""
        return ''.join([self.itos[c] for c in arr])

    def __getitem__(self, idx: int):
        """
        Get an input and target pair for auto-regressive modelling.

        Every call generates a fresh random sequence, so `idx` is unused. The
        target is the input shifted one token to the left.
        """
        s = torch.tensor(self.get_packed_math_input())
        return s[:self.seq_len], s[1:self.seq_len + 1]

    def __len__(self):
        """Number of sequences per epoch."""
        return self.n_sequences

Arithmetic Task Experiment Configurations

class ArithmeticAutoregression(NLPAutoRegressionConfigs):
    """
    ## Arithmetic Task Experiment Configurations

    Trains an auto-regressive model on packed addition problems generated by
    `ArithmeticDataset`, and scores it by greedy decoding in `sample`.
    """
    # Maximum number of digits per operand integer
    max_digits: int = 4
    # Number of training sequences per epoch
    train_sequences_per_epoch: int = 2 ** 12
    # Training data loader
    train_loader: DataLoader = 'arithmetic_train_loader'
    # Number of problems in evaluation
    n_tests: int = 64
    # No need of a validation dataset
    validator = None
    # Number of times to run evaluations per epoch
    inner_iterations = 4
    # Number of tokens in the vocabulary
    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)

    @torch.no_grad()
    def sample(self):
        """
        ### Evaluation

        We use the sampling function to evaluate the model on a set of problems.

        Each question (e.g. `x=11+29;`) is fed one token at a time. While the
        question is not exhausted, its next character replaces the model's
        prediction. After that, the model's greedy predictions are kept. A
        problem counts as solved when the text after the last `x==` (up to the
        first new line) equals the true sum. The fraction solved is saved as
        `score`.
        """
        # Skip in the first epoch
        if self.training_loop.idx < 1:
            return

        # Create a dataset to generate problems
        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
        # Get a set of problems and answers
        qa = [dataset.get_qa() for _ in range(self.n_tests)]
        # Collect the problems only
        questions = [p[0] for p in qa]

        # NOTE(review): training sequences prefix every problem with '?'
        # (see `ArithmeticDataset.get_packed_math_input`), but these prompts
        # start directly at 'x'. Confirm this train/eval mismatch is intended.

        # Create a tensor with only the initial token. Shape is `[1, n_tests]`:
        # sequence dimension first, one column per problem.
        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
        # Move to device
        data = data.to(self.device)

        # Boolean mask of sequences that have completed
        finished = torch.zeros((len(questions),)).bool().to(self.device)
        # Token id of the new line character - this marks end of the answer
        new_line = dataset.stoi['\n']

        # Sampled results
        results = [p[0] for p in questions]

        # Sample upto sequence length
        for i in monit.iterate('Sample', self.seq_len - 1):
            # If all the sequences have completed we skip this. `continue` (not
            # `break`) keeps the original behaviour of draining `monit.iterate`,
            # presumably so its progress section closes normally; verify before
            # changing.
            if finished.all():
                continue

            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy) for the last position
            output = output[-1].argmax(dim=-1)

            # Override with the question while it is still being fed in. This
            # must happen *before* the new-line check: a new line predicted
            # inside the question is discarded, so it must not mark the
            # sequence as finished.
            for j, p in enumerate(questions):
                if len(p) > i + 1:
                    output[j] = dataset.stoi[p[i + 1]]

            # Find which sequences have finished
            finished = finished | (output == new_line)
            # Skip if all have finished
            if finished.all():
                continue

            # Add the next token to the input
            data = torch.cat([data, output[None, :]], dim=0)

            # Get the sampled results; `tolist()` copies to the host once
            # instead of once per token.
            for j, c in enumerate(output.tolist()):
                results[j] += dataset.itos[c]

        # Discard everything after the answer in the results
        results = [r.split('\n')[0] for r in results]

        # Log a sample
        res_sample = results[0].split(';')
        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])

        # Get the answers: the text after the last `x==`
        answers = [r.split('x==')[-1] for r in results]

        # Count the number of correct answers
        correct = sum(answer == expected for answer, (_, expected) in zip(answers, qa))

        # Log the score
        tracker.save('score', correct / len(answers))

Training data loader

@option(ArithmeticAutoregression.train_loader)
def arithmetic_train_loader(c: ArithmeticAutoregression):
    """
    Training data loader.

    Serves `c.train_sequences_per_epoch` packed arithmetic sequences per epoch
    in batches of `c.batch_size`. Batches are collated with `transpose_batch`
    and loaded by 4 worker processes.
    """
    train_dataset = ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch)
    return DataLoader(
        train_dataset,
        batch_size=c.batch_size,
        collate_fn=transpose_batch,
        num_workers=4,
    )

Code to print a sample of generated problems for manual inspection

def _test():
    """
    Print one packed sequence of generated problems for manual inspection.
    Uses `seq_len=256`, `max_digits=8` and `n_sequences=10`.
    """
    sample_dataset = ArithmeticDataset(256, 8, 10)
    packed_ids = sample_dataset.get_packed_math_input()
    print(sample_dataset.decode(packed_ids))

# When run as a script, print a sample of generated problems.
if __name__ == '__main__':
    _test()