これはジョルジュ・ハリク (@gharik) のコードに基づいています。
11import random
12import string
13from typing import List
14
15import torch
16from labml.logger import Text
17from torch.utils.data import DataLoader, Dataset
18
19from labml import monit, logger, tracker
20from labml.configs import option
21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch24class ArithmeticDataset(Dataset):seq_len
生成された数学問題のシーケンス長です。この長さまでできるだけ多くの問題を解きます。max_digits: はオペランド整数の最大桁数:n_sequences: はエポックあたりのシーケンス数34    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):41        self.n_sequences = n_sequences
42        self.max_digits = max_digits
43        self.seq_len = seq_lenトークン ID を文字列に
45        self.itos = list(string.digits + 'xe =\n?+;')文字からトークン ID へ
47        self.stoi = {c: i for i, c in enumerate(self.itos)}n_digit
桁数の整数を生成します
49    @staticmethod
50    def make_int(n_digits: int):54        res = 0
55        for i in range(n_digits):
56            d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
57            res = res * 10 + d
58
59        return resの作業を生成します。x + y
たとえば11+29
、生成する場合などです1e0+9e0+0e0=10e0 1e0+2e0+1e0=4e0
。
61    @staticmethod
62    def get_add_explanation(x: int, y: int):69        carry = 0
70        e = 0
71        explanation = []
72        while x > 0 or y > 0 or carry > 0:
73            rx, ry = x % 10, y % 10
74            total = rx + ry + carry
75            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
76            x, y, carry = x // 10, y // 10, total // 10
77            e += 1
78
79        return ' '.join(explanation)82    def make_add_problem(self):86        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
87        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
88
89        explanation = self.get_add_explanation(x, y)
90        return f"x={x}+{y}; {explanation} x=={x + y}\n"算術問題を出して、答えを出してください。これは評価に使用されます。
92    def get_qa(self):96        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
97        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
98
99        return f'x={x}+{y};', f'{x + y}'複数の問題を生成し、それらをシーケンスにまとめます。
101    def get_packed_math_input(self):105        s_enc = []
106        while len(s_enc) <= self.seq_len:
107            s_part = self.make_add_problem()
108            s_part_enc = self.encode('?' + s_part)
109            s_enc = s_enc + s_part_enc
110        return s_enc与えられた文字列をエンコードする
112    def encode(self, s: str):116        return [self.stoi[c] for c in s]トークン ID のリストをデコードする
118    def decode(self, arr: List[int]):122        return ''.join([self.itos[c] for c in arr])自己回帰モデリングの入力とターゲットのペアを取得
124    def __getitem__(self, idx: int):128        s = torch.tensor(self.get_packed_math_input())
129        return s[:self.seq_len], s[1:self.seq_len + 1]エポックあたりのシーケンス数
131    def __len__(self):135        return self.n_sequences138class ArithmeticAutoregression(NLPAutoRegressionConfigs):オペランド整数あたりの最大桁数
143    max_digits: int = 4エポックあたりのトレーニングシーケンスの数
145    train_sequences_per_epoch: int = 2 ** 12トレーニングデータローダー
147    train_loader: DataLoader = 'arithmetic_train_loader'評価中の問題の数
149    n_tests: int = 64検証データセットは不要
151    validator = Noneエポックごとに評価を実行する回数
153    inner_iterations = 4ボキャブラリーのトークンの数
155    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)157    @torch.no_grad()
158    def sample(self):最初のエポックをスキップ
166        if self.training_loop.idx < 1:
167            return問題を生成するためのデータセットの作成
170        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)一連の問題と回答を入手
172        qa = [dataset.get_qa() for _ in range(self.n_tests)]問題だけ集めよう
174        questions = [p[0] for p in qa]最初のトークンのみでテンソルを作成
177        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])デバイスに移動
179        data = data.to(self.device)完了したシーケンスの数
182        finished = torch.zeros((len(questions),)).bool().to(self.device)改行文字のトークンID-これで回答の最後になります
184        new_line = dataset.stoi['\n']サンプル結果
187        results = [p[0] for p in questions]シーケンス長までのサンプル
190        for i in monit.iterate('Sample', self.seq_len - 1):すべてのシーケンスが完了したらこれをスキップします。
192            if finished.sum() == len(finished):
193                continueモデル出力を取得
196            output, *_ = self.model(data)モデル予測を取得 (欲張り)
198            output = output[-1].argmax(dim=-1)どのシーケンスが終了したか調べる
201            finished = finished | (output == new_line)すべて終了したらスキップ
203            if finished.sum() == len(finished):
204                continue質問で上書き
207            for j, p in enumerate(questions):
208                if len(p) > i + 1:
209                    output[j] = dataset.stoi[p[i + 1]]次のトークンを入力に追加します
212            data = torch.cat([data, output[None, :]], dim=0)サンプル結果を取得
215            for j, c in enumerate(output):
216                results[j] += dataset.itos[c]結果の回答の後に続くものはすべて破棄してください
219        results = [r.split('\n')[0] for r in results]サンプルをログに記録する
222        res_sample = results[0].split(';')
223        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])答えをゲット
226        results = [r.split('x==')[-1] for r in results]正解の数を数える
229        correct = 0
230        for r, _qa in zip(results, qa):
231            if r == _qa[1]:
232                correct += 1スコアを記録する
235        tracker.save('score', correct / len(results))トレーニングデータローダー
238@option(ArithmeticAutoregression.train_loader)
239def arithmetic_train_loader(c: ArithmeticAutoregression):243    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
244                      batch_size=c.batch_size,
245                      collate_fn=transpose_batch,
246                      num_workers=4)生成された問題をテストするコード
249def _test():253    dataset = ArithmeticDataset(256, 8, 10)
254
255    print(dataset.decode(dataset.get_packed_math_input()))259if __name__ == '__main__':
260    _test()