这是基于乔治·哈里克(@gharik)的代码。
11import random
12import string
13from typing import List
14
15import torch
16from labml.logger import Text
17from torch.utils.data import DataLoader, Dataset
18
19from labml import monit, logger, tracker
20from labml.configs import option
21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
24class ArithmeticDataset(Dataset):
seq_len
是生成的数学问题的序列长度。我们尽可能多地填写问题,直到这个长度:max_digits: 是操作数中的最大位数整数:n_sequences: 是每个纪元的序列数34 def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
41 self.n_sequences = n_sequences
42 self.max_digits = max_digits
43 self.seq_len = seq_len
令牌 ID 转换为字符串
45 self.itos = list(string.digits + 'xe =\n?+;')
字符到令牌 ID
47 self.stoi = {c: i for i, c in enumerate(self.itos)}
生成一个包含位n_digit
数的整数
49 @staticmethod
50 def make_int(n_digits: int):
54 res = 0
55 for i in range(n_digits):
56 d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
57 res = res * 10 + d
58
59 return res
生成的工作原理x + y
。例如,11+29
它会生成1e0+9e0+0e0=10e0 1e0+2e0+1e0=4e0
。
61 @staticmethod
62 def get_add_explanation(x: int, y: int):
69 carry = 0
70 e = 0
71 explanation = []
72 while x > 0 or y > 0 or carry > 0:
73 rx, ry = x % 10, y % 10
74 total = rx + ry + carry
75 explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
76 x, y, carry = x // 10, y // 10, total // 10
77 e += 1
78
79 return ' '.join(explanation)
82 def make_add_problem(self):
86 x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
87 y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
88
89 explanation = self.get_add_explanation(x, y)
90 return f"x={x}+{y}; {explanation} x=={x + y}\n"
获取算术问题和答案。这用于评估。
92 def get_qa(self):
96 x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
97 y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
98
99 return f'x={x}+{y};', f'{x + y}'
生成多个问题并将它们打包成一个序列。
101 def get_packed_math_input(self):
105 s_enc = []
106 while len(s_enc) <= self.seq_len:
107 s_part = self.make_add_problem()
108 s_part_enc = self.encode('?' + s_part)
109 s_enc = s_enc + s_part_enc
110 return s_enc
对给定字符串进行编码
112 def encode(self, s: str):
116 return [self.stoi[c] for c in s]
解码令牌 ID 列表
118 def decode(self, arr: List[int]):
122 return ''.join([self.itos[c] for c in arr])
获取自动回归建模的输入和目标对
124 def __getitem__(self, idx: int):
128 s = torch.tensor(self.get_packed_math_input())
129 return s[:self.seq_len], s[1:self.seq_len + 1]
每个纪元的序列数
131 def __len__(self):
135 return self.n_sequences
138class ArithmeticAutoregression(NLPAutoRegressionConfigs):
每个操作数整数的最大位数
143 max_digits: int = 4
每个纪元的训练序列数
145 train_sequences_per_epoch: int = 2 ** 12
训练数据加载器
147 train_loader: DataLoader = 'arithmetic_train_loader'
评估中的问题数量
149 n_tests: int = 64
不需要验证数据集
151 validator = None
每个纪元运行评估的次数
153 inner_iterations = 4
词汇表中的代币数量
155 n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
157 @torch.no_grad()
158 def sample(self):
跳过第一个纪元
166 if self.training_loop.idx < 1:
167 return
创建数据集以生成问题
170 dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
获取一系列问题和答案
172 qa = [dataset.get_qa() for _ in range(self.n_tests)]
只收集问题
174 questions = [p[0] for p in qa]
仅使用初始令牌创建张量
177 data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
移至设备
179 data = data.to(self.device)
已完成的序列数
182 finished = torch.zeros((len(questions),)).bool().to(self.device)
换行符的标记 ID-这标志着答案的结束
184 new_line = dataset.stoi['\n']
抽样结果
187 results = [p[0] for p in questions]
样本直至序列长度
190 for i in monit.iterate('Sample', self.seq_len - 1):
如果所有的序列都完成了,我们就跳过这个
192 if finished.sum() == len(finished):
193 continue
获取模型输出
196 output, *_ = self.model(data)
获取模型预测(贪婪)
198 output = output[-1].argmax(dim=-1)
找出哪些序列已完成
201 finished = finished | (output == new_line)
如果全部完成,则跳过
203 if finished.sum() == len(finished):
204 continue
用问题覆盖
207 for j, p in enumerate(questions):
208 if len(p) > i + 1:
209 output[j] = dataset.stoi[p[i + 1]]
将下一个令牌添加到输入中
212 data = torch.cat([data, output[None, :]], dim=0)
获取抽样结果
215 for j, c in enumerate(output):
216 results[j] += dataset.itos[c]
丢弃结果中答案后的所有内容
219 results = [r.split('\n')[0] for r in results]
记录样本
222 res_sample = results[0].split(';')
223 logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
得到答案
226 results = [r.split('x==')[-1] for r in results]
计算正确答案的数量
229 correct = 0
230 for r, _qa in zip(results, qa):
231 if r == _qa[1]:
232 correct += 1
记录分数
235 tracker.save('score', correct / len(results))
训练数据加载器
238@option(ArithmeticAutoregression.train_loader)
239def arithmetic_train_loader(c: ArithmeticAutoregression):
243 return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
244 batch_size=c.batch_size,
245 collate_fn=transpose_batch,
246 num_workers=4)
用于测试生成的问题的代码
249def _test():
253 dataset = ArithmeticDataset(256, 8, 10)
254
255 print(dataset.decode(dataset.get_packed_math_input()))
259if __name__ == '__main__':
260 _test()