මෙය ජෝර්ජස් හාරික් (@gharik)විසින් කරන ලද කේතය මත පදනම් වේ.

11import random
12import string
13from typing import List
14
15import torch
16from labml.logger import Text
17from torch.utils.data import DataLoader, Dataset
18
19from labml import monit, logger, tracker
20from labml.configs import option
21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch

අංකගණිත දත්ත කට්ටලය

මෙයඅංක ගණිතමය එකතු කිරීමේ ගැටළු සහ ක්රියාකාරිත්වය සමඟ විසඳුම් නිර්මාණය කරයි. අපි මෙතෙක් ක්රියාත්මක කර ඇත්තේ එකතු කිරීම පමණි.

එයපදනම් වී ඇත්තේ චරිත මට්ටමේ ටෝකනීකරණය මත ය.

24class ArithmeticDataset(Dataset):
  • seq_len යනු ජනනය කරන ලද ගණිත ගැටළු වල අනුක්රමික දිගයි. මෙම දිග දක්වා අපි හැකි තරම් ගැටළු පුරවන්නෙමු: max_digits: යනු ඔපෙරන්ඩ් සංඛ්‍යාවේ උපරිම ඉලක්කම් සංඛ්‍යාව වේ: n_sequences: යනු එපෝච් එකකට අනුක්‍රම ගණන
වේ
34    def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
41        self.n_sequences = n_sequences
42        self.max_digits = max_digits
43        self.seq_len = seq_len

නූලටටෝකන් හැඳුනුම්පත

45        self.itos = list(string.digits + 'xe =\n?+;')

අක්ෂරටෝකන් හැඳුනුම්පතට

47        self.stoi = {c: i for i, c in enumerate(self.itos)}

ඉලක්කම් n_digit ගණන සහිත පූර්ණ සංඛ්යාවක් ජනනය කරයි

49    @staticmethod
50    def make_int(n_digits: int):
54        res = 0
55        for i in range(n_digits):
56            d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
57            res = res * 10 + d
58
59        return res

සඳහාක්රියාකාරිත්වය ජනනය කරයි x + y . උදාහරණයක් ලෙස 11+29 එය ජනනය කිරීම සඳහා 1e0+9e0+0e0=10e0 1e0+2e0+1e0=4e0 .

61    @staticmethod
62    def get_add_explanation(x: int, y: int):
69        carry = 0
70        e = 0
71        explanation = []
72        while x > 0 or y > 0 or carry > 0:
73            rx, ry = x % 10, y % 10
74            total = rx + ry + carry
75            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
76            x, y, carry = x // 10, y // 10, total // 10
77            e += 1
78
79        return ' '.join(explanation)

පූර්වපැහැදිලි කිරීමක් සමඟ ගැටළුවක් ඇති කරන්න හෝ නැත

workingsහා පිළිතුරු සමග අංක ගණිතමය එකතු කිරීමේ ගැටලුවක් නිර්මාණය කරයි.

82    def make_add_problem(self):
86        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
87        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
88
89        explanation = self.get_add_explanation(x, y)
90        return f"x={x}+{y}; {explanation} x=={x + y}\n"

අංකගණිතමය ගැටළුව සහ පිළිතුරු ලබා ගන්න. මෙය ඇගයීම සඳහා භාවිතා වේ.

92    def get_qa(self):
96        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
97        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
98
99        return f'x={x}+{y};', f'{x + y}'

බහුගැටළු ජනනය කර ඒවා අනුපිළිවෙලකට ඇසුරුම් කරන්න.

101    def get_packed_math_input(self):
105        s_enc = []
106        while len(s_enc) <= self.seq_len:
107            s_part = self.make_add_problem()
108            s_part_enc = self.encode('?' + s_part)
109            s_enc = s_enc + s_part_enc
110        return s_enc

දීඇති නූලක් කේතනය කරන්න

112    def encode(self, s: str):
116        return [self.stoi[c] for c in s]

ටෝකන්හැඳුනුම්පත් ලැයිස්තුවක් විකේතනය කරන්න

118    def decode(self, arr: List[int]):
122        return ''.join([self.itos[c] for c in arr])

ස්වයංක්රීයප්රතිගාමී ආකෘති නිර්මාණය සඳහා ආදාන සහ ඉලක්ක යුගලයක් ලබා ගන්න

124    def __getitem__(self, idx: int):
128        s = torch.tensor(self.get_packed_math_input())
129        return s[:self.seq_len], s[1:self.seq_len + 1]

එපෝච්එකකට අනුපිළිවෙලවල් ගණන

131    def __len__(self):
135        return self.n_sequences

අංකගණිත කාර්ය අත්හදා බැලීමේ වින්යාසයන්

138class ArithmeticAutoregression(NLPAutoRegressionConfigs):

ක්රියාකාරීපූර්ණ සංඛ්යාවකට උපරිම ඉලක්කම් ගණන

143    max_digits: int = 4

එපෝච්එකකට පුහුණු අනුපිළිවෙල ගණන

145    train_sequences_per_epoch: int = 2 ** 12

පුහුණුදත්ත පැටවුම

147    train_loader: DataLoader = 'arithmetic_train_loader'

ඇගයීමේගැටළු ගණන

149    n_tests: int = 64

වලංගුදත්ත කට්ටලයක් අවශ්ය නොවේ

151    validator = None

එපෝච්එකකට ඇගයීම් ක්රියාත්මක කිරීමට වාර ගණන

153    inner_iterations = 4

වචනමාලාවේ ටෝකන ගණන

155    n_tokens = len(ArithmeticDataset(1, 1, 1).itos)

ඇගයීම

ගැටළුසමූහයක් මත ආකෘතිය ඇගයීම සඳහා අපි නියැදි ශ්රිතය භාවිතා කරමු

157    @torch.no_grad()
158    def sample(self):

පළමුඑපෝච් එකේ මඟ හරින්න

166        if self.training_loop.idx < 1:
167            return

ගැටළුජනනය කිරීම සඳහා දත්ත කට්ටලයක් සාදන්න

170        dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)

ගැටළුසහ පිළිතුරු සමූහයක් ලබා ගන්න

172        qa = [dataset.get_qa() for _ in range(self.n_tests)]

ගැටළුපමණක් එකතු කරන්න

174        questions = [p[0] for p in qa]

ආරම්භකටෝකනය පමණක් සහිත ටෙන්සරයක් සාදන්න

177        data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])

උපාංගයවෙත ගෙන යන්න

179        data = data.to(self.device)

සම්පූර්ණකර ඇති අනුක්රම ගණන

182        finished = torch.zeros((len(questions),)).bool().to(self.device)

නවරේඛා අක්ෂරයෙහි ටෝකන් හැඳුනුම්පත - මෙය පිළිතුරේ අවසානය සලකුණු කරයි

184        new_line = dataset.stoi['\n']

නියැදිප්රති. ල

187        results = [p[0] for p in questions]

අනුක්රමිකදිග දක්වා නියැදිය

190        for i in monit.iterate('Sample', self.seq_len - 1):

සියලුමඅනුක්රමයන් සම්පූර්ණ කර ඇත්නම් අපි මෙය මඟ හරිමු

192            if finished.sum() == len(finished):
193                continue

ආදර්ශප්රතිදානය ලබා ගන්න

196            output, *_ = self.model(data)

ආදර්ශඅනාවැකිය ලබා ගන්න (කෑදර)

198            output = output[-1].argmax(dim=-1)

කුමනඅනුපිළිවෙලවල් අවසන් කර ඇත්දැයි සොයා ගන්න

201            finished = finished | (output == new_line)

සියල්ලඅවසන් වී ඇත්නම් මඟ හරින්න

203            if finished.sum() == len(finished):
204                continue

ප්රශ්නයසමඟ අභිබවා යන්න

207            for j, p in enumerate(questions):
208                if len(p) > i + 1:
209                    output[j] = dataset.stoi[p[i + 1]]

ආදානයටඊළඟ ටෝකනය එක් කරන්න

212            data = torch.cat([data, output[None, :]], dim=0)

නියැදිප්රතිඵල ලබා ගන්න

215            for j, c in enumerate(output):
216                results[j] += dataset.itos[c]

ප්රතිඵලවලපිළිතුරෙන් පසු සියල්ල ඉවතලන්න

219        results = [r.split('\n')[0] for r in results]

නියැදියක්ලොග් කරන්න

222        res_sample = results[0].split(';')
223        logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])

පිළිතුරුලබා ගන්න

226        results = [r.split('x==')[-1] for r in results]

නිවැරදිපිළිතුරු ගණන ගණන් කරන්න

229        correct = 0
230        for r, _qa in zip(results, qa):
231            if r == _qa[1]:
232                correct += 1

ලකුණුලොග් කරන්න

235        tracker.save('score', correct / len(results))

පුහුණුදත්ත පැටවුම

238@option(ArithmeticAutoregression.train_loader)
239def arithmetic_train_loader(c: ArithmeticAutoregression):
243    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
244                      batch_size=c.batch_size,
245                      collate_fn=transpose_batch,
246                      num_workers=4)

ජනනයකරන ලද ගැටළු පරීක්ෂා කිරීම සඳහා කේතය

249def _test():
253    dataset = ArithmeticDataset(256, 8, 10)
254
255    print(dataset.decode(dataset.get_packed_math_input()))

259if __name__ == '__main__':
260    _test()