これはジョルジュ・ハリク (@gharik) のコードに基づいています。

11import random
12import string
13from typing import List
14
15import torch
16from labml.logger import Text
17from torch.utils.data import DataLoader, Dataset
18
19from labml import monit, logger, tracker
20from labml.configs import option
21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch

算術データセット

これにより、算術加算の問題と解法が生成されます。今のところ、追加を実装しただけです。

キャラクターレベルのトークン化に基づいています。

24class ArithmeticDataset(Dataset):
  • seq_len 生成された数学問題のシーケンス長です。この長さまでできるだけ多くの問題を解きます。max_digits: はオペランド整数の最大桁数:n_sequences: はエポックあたりのシーケンス数
34    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
41        self.n_sequences = n_sequences
42        self.max_digits = max_digits
43        self.seq_len = seq_len

トークン ID を文字列に

45        self.itos = list(string.digits + 'xe =\n?+;')

文字からトークン ID へ

47        self.stoi = {c: i for i, c in enumerate(self.itos)}

n_digit 桁数の整数を生成します

49    @staticmethod
50    def make_int(n_digits: int):
54        res = 0
55        for i in range(n_digits):
56            d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
57            res = res * 10 + d
58
59        return res

の作業を生成します。x + y たとえば11+29 、生成する場合などです1e0+9e0+0e0=10e0 1e0+2e0+1e0=4e0

61    @staticmethod
62    def get_add_explanation(x: int, y: int):
69        carry = 0
70        e = 0
71        explanation = []
72        while x > 0 or y > 0 or carry > 0:
73            rx, ry = x % 10, y % 10
74            total = rx + ry + carry
75            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
76            x, y, carry = x // 10, y // 10, total // 10
77            e += 1
78
79        return ' '.join(explanation)

pre_explanation で問題を起こすかしないか

計算と解を含む算術加算問題を作成します。

82    def make_add_problem(self):
86        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
87        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
88
89        explanation = self.get_add_explanation(x, y)
90        return f"x={x}+{y}; {explanation} x=={x + y}\n"

算術問題を出して、答えを出してください。これは評価に使用されます。

92    def get_qa(self):
96        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
97        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
98
99        return f'x={x}+{y};', f'{x + y}'

複数の問題を生成し、それらをシーケンスにまとめます。

101    def get_packed_math_input(self):
105        s_enc = []
106        while len(s_enc) <= self.seq_len:
107            s_part = self.make_add_problem()
108            s_part_enc = self.encode('?' + s_part)
109            s_enc = s_enc + s_part_enc
110        return s_enc

与えられた文字列をエンコードする

112    def encode(self, s: str):
116        return [self.stoi[c] for c in s]

トークン ID のリストをデコードする

118    def decode(self, arr: List[int]):
122        return ''.join([self.itos[c] for c in arr])

自己回帰モデリングの入力とターゲットのペアを取得

124    def __getitem__(self, idx: int):
128        s = torch.tensor(self.get_packed_math_input())
129        return s[:self.seq_len], s[1:self.seq_len + 1]

エポックあたりのシーケンス数

131    def __len__(self):
135        return self.n_sequences

算術タスク実験構成

138class ArithmeticAutoregression(NLPAutoRegressionConfigs):

オペランド整数あたりの最大桁数

143    max_digits: int = 4

エポックあたりのトレーニングシーケンスの数

145    train_sequences_per_epoch: int = 2 ** 12

トレーニングデータローダー

147    train_loader: DataLoader = 'arithmetic_train_loader'

評価中の問題の数

149    n_tests: int = 64

検証データセットは不要

151    validator = None

エポックごとに評価を実行する回数

153    inner_iterations = 4

ボキャブラリーのトークンの数

155    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)

評価

サンプリング関数を使用して、一連の問題についてモデルを評価します。

157    @torch.no_grad()
158    def sample(self):

最初のエポックをスキップ

166        if self.training_loop.idx < 1:
167            return

問題を生成するためのデータセットの作成

170        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)

一連の問題と回答を入手

172        qa = [dataset.get_qa() for _ in range(self.n_tests)]

問題だけ集めよう

174        questions = [p[0] for p in qa]

最初のトークンのみでテンソルを作成

177        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])

デバイスに移動

179        data = data.to(self.device)

完了したシーケンスの数

182        finished = torch.zeros((len(questions),)).bool().to(self.device)

改行文字のトークンID-これで回答の最後になります

184        new_line = dataset.stoi['\n']

サンプル結果

187        results = [p[0] for p in questions]

シーケンス長までのサンプル

190        for i in monit.iterate('Sample', self.seq_len - 1):

すべてのシーケンスが完了したらこれをスキップします。

192            if finished.sum() == len(finished):
193                continue

モデル出力を取得

196            output, *_ = self.model(data)

モデル予測を取得 (欲張り)

198            output = output[-1].argmax(dim=-1)

どのシーケンスが終了したか調べる

201            finished = finished | (output == new_line)

すべて終了したらスキップ

203            if finished.sum() == len(finished):
204                continue

質問で上書き

207            for j, p in enumerate(questions):
208                if len(p) > i + 1:
209                    output[j] = dataset.stoi[p[i + 1]]

次のトークンを入力に追加します

212            data = torch.cat([data, output[None, :]], dim=0)

サンプル結果を取得

215            for j, c in enumerate(output):
216                results[j] += dataset.itos[c]

結果の回答の後に続くものはすべて破棄してください

219        results = [r.split('\n')[0] for r in results]

サンプルをログに記録する

222        res_sample = results[0].split(';')
223        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])

答えをゲット

226        results = [r.split('x==')[-1] for r in results]

正解の数を数える

229        correct = 0
230        for r, _qa in zip(results, qa):
231            if r == _qa[1]:
232                correct += 1

スコアを記録する

235        tracker.save('score', correct / len(results))

トレーニングデータローダー

238@option(ArithmeticAutoregression.train_loader)
239def arithmetic_train_loader(c: ArithmeticAutoregression):
243    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
244                      batch_size=c.batch_size,
245                      collate_fn=transpose_batch,
246                      num_workers=4)

生成された問題をテストするコード

249def _test():
253    dataset = ArithmeticDataset(256, 8, 10)
254
255    print(dataset.decode(dataset.get_packed_math_input()))

259if __name__ == '__main__':
260    _test()