これはジョルジュ・ハリク (@gharik) のコードに基づいています。
11import random
12import string
13from typing import List
14
15import torch
16from labml.logger import Text
17from torch.utils.data import DataLoader, Dataset
18
19from labml import monit, logger, tracker
20from labml.configs import option
21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
24class ArithmeticDataset(Dataset):
seq_len
生成された数学問題のシーケンス長です。この長さまでできるだけ多くの問題を解きます。max_digits: はオペランド整数の最大桁数:n_sequences: はエポックあたりのシーケンス数34 def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
41 self.n_sequences = n_sequences
42 self.max_digits = max_digits
43 self.seq_len = seq_len
トークン ID を文字列に
45 self.itos = list(string.digits + 'xe =\n?+;')
文字からトークン ID へ
47 self.stoi = {c: i for i, c in enumerate(self.itos)}
n_digit
桁数の整数を生成します
49 @staticmethod
50 def make_int(n_digits: int):
54 res = 0
55 for i in range(n_digits):
56 d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
57 res = res * 10 + d
58
59 return res
の作業を生成します。x + y
たとえば11+29
、生成する場合などです1e0+9e0+0e0=10e0 1e0+2e0+1e0=4e0
。
61 @staticmethod
62 def get_add_explanation(x: int, y: int):
69 carry = 0
70 e = 0
71 explanation = []
72 while x > 0 or y > 0 or carry > 0:
73 rx, ry = x % 10, y % 10
74 total = rx + ry + carry
75 explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
76 x, y, carry = x // 10, y // 10, total // 10
77 e += 1
78
79 return ' '.join(explanation)
82 def make_add_problem(self):
86 x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
87 y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
88
89 explanation = self.get_add_explanation(x, y)
90 return f"x={x}+{y}; {explanation} x=={x + y}\n"
算術問題を出して、答えを出してください。これは評価に使用されます。
92 def get_qa(self):
96 x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
97 y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
98
99 return f'x={x}+{y};', f'{x + y}'
複数の問題を生成し、それらをシーケンスにまとめます。
101 def get_packed_math_input(self):
105 s_enc = []
106 while len(s_enc) <= self.seq_len:
107 s_part = self.make_add_problem()
108 s_part_enc = self.encode('?' + s_part)
109 s_enc = s_enc + s_part_enc
110 return s_enc
与えられた文字列をエンコードする
112 def encode(self, s: str):
116 return [self.stoi[c] for c in s]
トークン ID のリストをデコードする
118 def decode(self, arr: List[int]):
122 return ''.join([self.itos[c] for c in arr])
自己回帰モデリングの入力とターゲットのペアを取得
124 def __getitem__(self, idx: int):
128 s = torch.tensor(self.get_packed_math_input())
129 return s[:self.seq_len], s[1:self.seq_len + 1]
エポックあたりのシーケンス数
131 def __len__(self):
135 return self.n_sequences
138class ArithmeticAutoregression(NLPAutoRegressionConfigs):
オペランド整数あたりの最大桁数
143 max_digits: int = 4
エポックあたりのトレーニングシーケンスの数
145 train_sequences_per_epoch: int = 2 ** 12
トレーニングデータローダー
147 train_loader: DataLoader = 'arithmetic_train_loader'
評価中の問題の数
149 n_tests: int = 64
検証データセットは不要
151 validator = None
エポックごとに評価を実行する回数
153 inner_iterations = 4
ボキャブラリーのトークンの数
155 n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
157 @torch.no_grad()
158 def sample(self):
最初のエポックをスキップ
166 if self.training_loop.idx < 1:
167 return
問題を生成するためのデータセットの作成
170 dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
一連の問題と回答を入手
172 qa = [dataset.get_qa() for _ in range(self.n_tests)]
問題だけ集めよう
174 questions = [p[0] for p in qa]
最初のトークンのみでテンソルを作成
177 data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
デバイスに移動
179 data = data.to(self.device)
完了したシーケンスの数
182 finished = torch.zeros((len(questions),)).bool().to(self.device)
改行文字のトークンID-これで回答の最後になります
184 new_line = dataset.stoi['\n']
サンプル結果
187 results = [p[0] for p in questions]
シーケンス長までのサンプル
190 for i in monit.iterate('Sample', self.seq_len - 1):
すべてのシーケンスが完了したらこれをスキップします。
192 if finished.sum() == len(finished):
193 continue
モデル出力を取得
196 output, *_ = self.model(data)
モデル予測を取得 (欲張り)
198 output = output[-1].argmax(dim=-1)
どのシーケンスが終了したか調べる
201 finished = finished | (output == new_line)
すべて終了したらスキップ
203 if finished.sum() == len(finished):
204 continue
質問で上書き
207 for j, p in enumerate(questions):
208 if len(p) > i + 1:
209 output[j] = dataset.stoi[p[i + 1]]
次のトークンを入力に追加します
212 data = torch.cat([data, output[None, :]], dim=0)
サンプル結果を取得
215 for j, c in enumerate(output):
216 results[j] += dataset.itos[c]
結果の回答の後に続くものはすべて破棄してください
219 results = [r.split('\n')[0] for r in results]
サンプルをログに記録する
222 res_sample = results[0].split(';')
223 logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
答えをゲット
226 results = [r.split('x==')[-1] for r in results]
正解の数を数える
229 correct = 0
230 for r, _qa in zip(results, qa):
231 if r == _qa[1]:
232 correct += 1
スコアを記録する
235 tracker.save('score', correct / len(results))
トレーニングデータローダー
238@option(ArithmeticAutoregression.train_loader)
239def arithmetic_train_loader(c: ArithmeticAutoregression):
243 return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
244 batch_size=c.batch_size,
245 collate_fn=transpose_batch,
246 num_workers=4)
生成された問題をテストするコード
249def _test():
253 dataset = ArithmeticDataset(256, 8, 10)
254
255 print(dataset.decode(dataset.get_packed_math_input()))
259if __name__ == '__main__':
260 _test()