මෙය ජෝර්ජස් හාරික් (@gharik)විසින් කරන ලද කේතය මත පදනම් වේ.
11import random
12import string
13from typing import List
14
15import torch
16from labml.logger import Text
17from torch.utils.data import DataLoader, Dataset
18
19from labml import monit, logger, tracker
20from labml.configs import option
21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
මෙයඅංක ගණිතමය එකතු කිරීමේ ගැටළු සහ ක්රියාකාරිත්වය සමඟ විසඳුම් නිර්මාණය කරයි. අපි මෙතෙක් ක්රියාත්මක කර ඇත්තේ එකතු කිරීම පමණි.
එයපදනම් වී ඇත්තේ චරිත මට්ටමේ ටෝකනීකරණය මත ය.
24class ArithmeticDataset(Dataset):
seq_len
යනු ජනනය කරන ලද ගණිත ගැටළු වල අනුක්රමික දිගයි. මෙම දිග දක්වා අපි හැකි තරම් ගැටළු පුරවන්නෙමු: max_digits: යනු ඔපෙරන්ඩ් සංඛ්යාවේ උපරිම ඉලක්කම් සංඛ්යාව වේ: n_sequences: යනු එපෝච් එකකට අනුක්රම ගණන34 def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
41 self.n_sequences = n_sequences
42 self.max_digits = max_digits
43 self.seq_len = seq_len
නූලටටෝකන් හැඳුනුම්පත
45 self.itos = list(string.digits + 'xe =\n?+;')
අක්ෂරටෝකන් හැඳුනුම්පතට
47 self.stoi = {c: i for i, c in enumerate(self.itos)}
ඉලක්කම් n_digit
ගණන සහිත පූර්ණ සංඛ්යාවක් ජනනය කරයි
49 @staticmethod
50 def make_int(n_digits: int):
54 res = 0
55 for i in range(n_digits):
56 d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
57 res = res * 10 + d
58
59 return res
සඳහාක්රියාකාරිත්වය ජනනය කරයි x + y
. උදාහරණයක් ලෙස 11+29
එය ජනනය කිරීම සඳහා 1e0+9e0+0e0=10e0 1e0+2e0+1e0=4e0
.
61 @staticmethod
62 def get_add_explanation(x: int, y: int):
69 carry = 0
70 e = 0
71 explanation = []
72 while x > 0 or y > 0 or carry > 0:
73 rx, ry = x % 10, y % 10
74 total = rx + ry + carry
75 explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
76 x, y, carry = x // 10, y // 10, total // 10
77 e += 1
78
79 return ' '.join(explanation)
පූර්වපැහැදිලි කිරීමක් සමඟ ගැටළුවක් ඇති කරන්න හෝ නැත
workingsහා පිළිතුරු සමග අංක ගණිතමය එකතු කිරීමේ ගැටලුවක් නිර්මාණය කරයි.
82 def make_add_problem(self):
86 x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
87 y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
88
89 explanation = self.get_add_explanation(x, y)
90 return f"x={x}+{y}; {explanation} x=={x + y}\n"
අංකගණිතමය ගැටළුව සහ පිළිතුරු ලබා ගන්න. මෙය ඇගයීම සඳහා භාවිතා වේ.
92 def get_qa(self):
96 x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
97 y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
98
99 return f'x={x}+{y};', f'{x + y}'
බහුගැටළු ජනනය කර ඒවා අනුපිළිවෙලකට ඇසුරුම් කරන්න.
101 def get_packed_math_input(self):
105 s_enc = []
106 while len(s_enc) <= self.seq_len:
107 s_part = self.make_add_problem()
108 s_part_enc = self.encode('?' + s_part)
109 s_enc = s_enc + s_part_enc
110 return s_enc
දීඇති නූලක් කේතනය කරන්න
112 def encode(self, s: str):
116 return [self.stoi[c] for c in s]
ටෝකන්හැඳුනුම්පත් ලැයිස්තුවක් විකේතනය කරන්න
118 def decode(self, arr: List[int]):
122 return ''.join([self.itos[c] for c in arr])
ස්වයංක්රීයප්රතිගාමී ආකෘති නිර්මාණය සඳහා ආදාන සහ ඉලක්ක යුගලයක් ලබා ගන්න
124 def __getitem__(self, idx: int):
128 s = torch.tensor(self.get_packed_math_input())
129 return s[:self.seq_len], s[1:self.seq_len + 1]
එපෝච්එකකට අනුපිළිවෙලවල් ගණන
131 def __len__(self):
135 return self.n_sequences
138class ArithmeticAutoregression(NLPAutoRegressionConfigs):
ක්රියාකාරීපූර්ණ සංඛ්යාවකට උපරිම ඉලක්කම් ගණන
143 max_digits: int = 4
එපෝච්එකකට පුහුණු අනුපිළිවෙල ගණන
145 train_sequences_per_epoch: int = 2 ** 12
පුහුණුදත්ත පැටවුම
147 train_loader: DataLoader = 'arithmetic_train_loader'
ඇගයීමේගැටළු ගණන
149 n_tests: int = 64
වලංගුදත්ත කට්ටලයක් අවශ්ය නොවේ
151 validator = None
එපෝච්එකකට ඇගයීම් ක්රියාත්මක කිරීමට වාර ගණන
153 inner_iterations = 4
වචනමාලාවේ ටෝකන ගණන
155 n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
157 @torch.no_grad()
158 def sample(self):
පළමුඑපෝච් එකේ මඟ හරින්න
166 if self.training_loop.idx < 1:
167 return
ගැටළුජනනය කිරීම සඳහා දත්ත කට්ටලයක් සාදන්න
170 dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
ගැටළුසහ පිළිතුරු සමූහයක් ලබා ගන්න
172 qa = [dataset.get_qa() for _ in range(self.n_tests)]
ගැටළුපමණක් එකතු කරන්න
174 questions = [p[0] for p in qa]
ආරම්භකටෝකනය පමණක් සහිත ටෙන්සරයක් සාදන්න
177 data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
උපාංගයවෙත ගෙන යන්න
179 data = data.to(self.device)
සම්පූර්ණකර ඇති අනුක්රම ගණන
182 finished = torch.zeros((len(questions),)).bool().to(self.device)
නවරේඛා අක්ෂරයෙහි ටෝකන් හැඳුනුම්පත - මෙය පිළිතුරේ අවසානය සලකුණු කරයි
184 new_line = dataset.stoi['\n']
නියැදිප්රති. ල
187 results = [p[0] for p in questions]
අනුක්රමිකදිග දක්වා නියැදිය
190 for i in monit.iterate('Sample', self.seq_len - 1):
සියලුමඅනුක්රමයන් සම්පූර්ණ කර ඇත්නම් අපි මෙය මඟ හරිමු
192 if finished.sum() == len(finished):
193 continue
ආදර්ශප්රතිදානය ලබා ගන්න
196 output, *_ = self.model(data)
ආදර්ශඅනාවැකිය ලබා ගන්න (කෑදර)
198 output = output[-1].argmax(dim=-1)
කුමනඅනුපිළිවෙලවල් අවසන් කර ඇත්දැයි සොයා ගන්න
201 finished = finished | (output == new_line)
සියල්ලඅවසන් වී ඇත්නම් මඟ හරින්න
203 if finished.sum() == len(finished):
204 continue
ප්රශ්නයසමඟ අභිබවා යන්න
207 for j, p in enumerate(questions):
208 if len(p) > i + 1:
209 output[j] = dataset.stoi[p[i + 1]]
ආදානයටඊළඟ ටෝකනය එක් කරන්න
212 data = torch.cat([data, output[None, :]], dim=0)
නියැදිප්රතිඵල ලබා ගන්න
215 for j, c in enumerate(output):
216 results[j] += dataset.itos[c]
ප්රතිඵලවලපිළිතුරෙන් පසු සියල්ල ඉවතලන්න
219 results = [r.split('\n')[0] for r in results]
නියැදියක්ලොග් කරන්න
222 res_sample = results[0].split(';')
223 logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
පිළිතුරුලබා ගන්න
226 results = [r.split('x==')[-1] for r in results]
නිවැරදිපිළිතුරු ගණන ගණන් කරන්න
229 correct = 0
230 for r, _qa in zip(results, qa):
231 if r == _qa[1]:
232 correct += 1
ලකුණුලොග් කරන්න
235 tracker.save('score', correct / len(results))
පුහුණුදත්ත පැටවුම
238@option(ArithmeticAutoregression.train_loader)
239def arithmetic_train_loader(c: ArithmeticAutoregression):
243 return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
244 batch_size=c.batch_size,
245 collate_fn=transpose_batch,
246 num_workers=4)
ජනනයකරන ලද ගැටළු පරීක්ෂා කිරීම සඳහා කේතය
249def _test():
253 dataset = ArithmeticDataset(256, 8, 10)
254
255 print(dataset.decode(dataset.get_packed_math_input()))
259if __name__ == '__main__':
260 _test()