这是基于乔治·哈里克(@gharik)的代码。

11import random
12import string
13from typing import List
14
15import torch
16from labml.logger import Text
17from torch.utils.data import DataLoader, Dataset
18
19from labml import monit, logger, tracker
20from labml.configs import option
21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch

算术数据集

这会产生算术加法问题和运作解决方案。到目前为止,我们只实施了加法。

它基于角色级别的标记化。

24class ArithmeticDataset(Dataset):
  • seq_len 是生成的数学问题的序列长度。我们尽可能多地填写问题,直到这个长度:max_digits: 是操作数中的最大位数整数:n_sequences: 是每个纪元的序列数
34    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
41        self.n_sequences = n_sequences
42        self.max_digits = max_digits
43        self.seq_len = seq_len

令牌 ID 转换为字符串

45        self.itos = list(string.digits + 'xe =\n?+;')

字符到令牌 ID

47        self.stoi = {c: i for i, c in enumerate(self.itos)}

生成一个包含位n_digit 数的整数

49    @staticmethod
50    def make_int(n_digits: int):
54        res = 0
55        for i in range(n_digits):
56            d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
57            res = res * 10 + d
58
59        return res

生成的工作原理x + y 。例如,11+29 它会生成1e0+9e0+0e0=10e0 1e0+2e0+1e0=4e0

61    @staticmethod
62    def get_add_explanation(x: int, y: int):
69        carry = 0
70        e = 0
71        explanation = []
72        while x > 0 or y > 0 or carry > 0:
73            rx, ry = x % 10, y % 10
74            total = rx + ry + carry
75            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
76            x, y, carry = x // 10, y // 10, total // 10
77            e += 1
78
79        return ' '.join(explanation)

不管是否用 pre_explansion 问问题

用运作和答案创建算术加法问题。

82    def make_add_problem(self):
86        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
87        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
88
89        explanation = self.get_add_explanation(x, y)
90        return f"x={x}+{y}; {explanation} x=={x + y}\n"

获取算术问题和答案。这用于评估。

92    def get_qa(self):
96        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
97        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
98
99        return f'x={x}+{y};', f'{x + y}'

生成多个问题并将它们打包成一个序列。

101    def get_packed_math_input(self):
105        s_enc = []
106        while len(s_enc) <= self.seq_len:
107            s_part = self.make_add_problem()
108            s_part_enc = self.encode('?' + s_part)
109            s_enc = s_enc + s_part_enc
110        return s_enc

对给定字符串进行编码

112    def encode(self, s: str):
116        return [self.stoi[c] for c in s]

解码令牌 ID 列表

118    def decode(self, arr: List[int]):
122        return ''.join([self.itos[c] for c in arr])

获取自动回归建模的输入和目标对

124    def __getitem__(self, idx: int):
128        s = torch.tensor(self.get_packed_math_input())
129        return s[:self.seq_len], s[1:self.seq_len + 1]

每个纪元的序列数

131    def __len__(self):
135        return self.n_sequences

算术任务实验配置

138class ArithmeticAutoregression(NLPAutoRegressionConfigs):

每个操作数整数的最大位数

143    max_digits: int = 4

每个纪元的训练序列数

145    train_sequences_per_epoch: int = 2 ** 12

训练数据加载器

147    train_loader: DataLoader = 'arithmetic_train_loader'

评估中的问题数量

149    n_tests: int = 64

不需要验证数据集

151    validator = None

每个纪元运行评估的次数

153    inner_iterations = 4

词汇表中的代币数量

155    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)

评估

我们使用采样函数来评估一组问题的模型

157    @torch.no_grad()
158    def sample(self):

跳过第一个纪元

166        if self.training_loop.idx < 1:
167            return

创建数据集以生成问题

170        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)

获取一系列问题和答案

172        qa = [dataset.get_qa() for _ in range(self.n_tests)]

只收集问题

174        questions = [p[0] for p in qa]

仅使用初始令牌创建张量

177        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])

移至设备

179        data = data.to(self.device)

已完成的序列数

182        finished = torch.zeros((len(questions),)).bool().to(self.device)

换行符的标记 ID-这标志着答案的结束

184        new_line = dataset.stoi['\n']

抽样结果

187        results = [p[0] for p in questions]

样本直至序列长度

190        for i in monit.iterate('Sample', self.seq_len - 1):

如果所有的序列都完成了,我们就跳过这个

192            if finished.sum() == len(finished):
193                continue

获取模型输出

196            output, *_ = self.model(data)

获取模型预测(贪婪)

198            output = output[-1].argmax(dim=-1)

找出哪些序列已完成

201            finished = finished | (output == new_line)

如果全部完成,则跳过

203            if finished.sum() == len(finished):
204                continue

用问题覆盖

207            for j, p in enumerate(questions):
208                if len(p) > i + 1:
209                    output[j] = dataset.stoi[p[i + 1]]

将下一个令牌添加到输入中

212            data = torch.cat([data, output[None, :]], dim=0)

获取抽样结果

215            for j, c in enumerate(output):
216                results[j] += dataset.itos[c]

丢弃结果中答案后的所有内容

219        results = [r.split('\n')[0] for r in results]

记录样本

222        res_sample = results[0].split(';')
223        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])

得到答案

226        results = [r.split('x==')[-1] for r in results]

计算正确答案的数量

229        correct = 0
230        for r, _qa in zip(results, qa):
231            if r == _qa[1]:
232                correct += 1

记录分数

235        tracker.save('score', correct / len(results))

训练数据加载器

238@option(ArithmeticAutoregression.train_loader)
239def arithmetic_train_loader(c: ArithmeticAutoregression):
243    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
244                      batch_size=c.batch_size,
245                      collate_fn=transpose_batch,
246                      num_workers=4)

用于测试生成的问题的代码

249def _test():
253    dataset = ArithmeticDataset(256, 8, 10)
254
255    print(dataset.decode(dataset.get_packed_math_input()))

259if __name__ == '__main__':
260    _test()