11from collections import Counter
12from typing import Callable
13
14import torch
15import torchtext
16from torch import nn
17from torch.utils.data import DataLoader
18import torchtext.vocab
19from torchtext.vocab import Vocab
20
21from labml import lab, tracker, monit
22from labml.configs import option
23from labml_helpers.device import DeviceConfigs
24from labml_helpers.metrics.accuracy import Accuracy
25from labml_helpers.module import Module
26from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
27from labml_nn.optimizers.configs import OptimizerConfigs
එන්එල්පීවර්ගීකරණ කාර්ය පුහුණුව සඳහා මූලික වින්යාසයන් මෙයට ඇත. සියලුම ගුණාංග වින්යාසගත කළ හැකිය.
30class NLPClassificationConfigs(TrainValidConfigs):
ප්රශස්තකරණය
41 optimizer: torch.optim.Adam
පුහුණුඋපාංගය
43 device: torch.device = DeviceConfigs()
ස්වයංක්රීයප්රතිගාමී ආකෘතිය
46 model: Module
කණ්ඩායම්ප්රමාණය
48 batch_size: int = 16
අනුක්රමයේදිග, හෝ සන්දර්භය ප්රමාණය
50 seq_len: int = 512
වචනමාලාව
52 vocab: Vocab = 'ag_news'
වචනමාලාවේ ටෝකන් ගණන
54 n_tokens: int
පන්තිගණන
56 n_classes: int = 'ag_news'
ටෝකනයිසර්
58 tokenizer: Callable = 'character'
වරින්වර ආකෘති සුරැකීමට යන්න
61 is_save_models = True
පාඩුශ්රිතය
64 loss_func = nn.CrossEntropyLoss()
නිරවද්යතාශ්රිතය
66 accuracy = Accuracy()
ආදර්ශකාවැද්දීමේ ප්රමාණය
68 d_model: int = 512
ශ්රේණියේක්ලිපින්
70 grad_norm_clip: float = 1.0
පුහුණුදත්ත පැටවුම
73 train_loader: DataLoader = 'ag_news'
වලංගුදත්ත පැටවුම
75 valid_loader: DataLoader = 'ag_news'
ආදර්ශපරාමිතීන් සහ අනුක්රමික ලඝු-සටහන යන්න (එක් එක් එක් එක් එක් වරක්). මේවා ස්ථරයකට සාරාංශගත සංඛ්යාන වේ, නමුත් එය තවමත් ඉතා ගැඹුරු ජාල සඳහා බොහෝ දර්ශක වලට හේතු විය හැකිය.
80 is_log_model_params_grads: bool = False
ආදර්ශසක්රිය කිරීම් ලොග් කළ යුතුද යන්න (එක් එක් ඊපෝච් එකකට වරක්). මේවා ස්ථරයකට සාරාංශගත සංඛ්යාන වේ, නමුත් එය තවමත් ඉතා ගැඹුරු ජාල සඳහා බොහෝ දර්ශක වලට හේතු විය හැකිය.
85 is_log_model_activations: bool = False
87 def init(self):
ට්රැකර්වින්යාසයන් සකසන්න
92 tracker.set_scalar("accuracy.*", True)
93 tracker.set_scalar("loss.*", True)
මොඩියුලප්රතිදානයන් ලොග් කිරීමට කොක්කක් එක් කරන්න
95 hook_model_outputs(self.mode, self.model, 'model')
රාජ්යමොඩියුලයක් ලෙස නිරවද්යතාව එක් කරන්න. RNs සඳහා පුහුණුව සහ වලංගු කිරීම අතර රාජ්යයන් ගබඩා කිරීම අදහස් කරන බැවින් නම බොහෝ විට ව්යාකූල වේ. මෙය පුහුණුව සහ වලංගු කිරීම සඳහා නිරවද්යතා මෙට්රික් සංඛ්යාන වෙනම තබා ගනී.
100 self.state_modules = [self.accuracy]
102 def step(self, batch: any, batch_idx: BatchIndex):
උපාංගයවෙත දත්ත ගෙනයන්න
108 data, target = batch[0].to(self.device), batch[1].to(self.device)
පුහුණුප්රකාරයේදී ගෝලීය පියවර යාවත්කාලීන කරන්න (සැකසූ ටෝකන ගණන)
111 if self.mode.is_train:
112 tracker.add_global_step(data.shape[1])
ආකෘතිප්රතිදානයන් ග්රහණය කර ගත යුතුද යන්න
115 with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
ආදර්ශප්රතිදානයන් ලබා ගන්න. ආර්එන්එස් භාවිතා කරන විට එය ප්රාන්ත සඳහා ටූල් එකක් නැවත ලබා දෙයි. මෙය තවම ක්රියාත්මක කර නැත 😜
119 output, *_ = self.model(data)
ගණනයකිරීම සහ ලොග් වීම
122 loss = self.loss_func(output, target)
123 tracker.add("loss.", loss)
ගණනයකිරීම සහ ලොග් කිරීමේ නිරවද්යතාවය
126 self.accuracy(output, target)
127 self.accuracy.track()
ආකෘතියපුහුණු කරන්න
130 if self.mode.is_train:
අනුක්රමිකගණනය කරන්න
132 loss.backward()
ක්ලිප්අනුක්රමික
134 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
ප්රශස්තිකරණපියවර ගන්න
136 self.optimizer.step()
සෑමයුගලයකම අවසාන කණ්ඩායමේ ආදර්ශ පරාමිතීන් සහ අනුක්රමික ලොග් කරන්න
138 if batch_idx.is_last and self.is_log_model_params_grads:
139 tracker.add('model', self.model)
අනුක්රමිකඉවත්
141 self.optimizer.zero_grad()
ලුහුබැඳඇති ප්රමිතික සුරකින්න
144 tracker.save()
147@option(NLPClassificationConfigs.optimizer)
148def _optimizer(c: NLPClassificationConfigs):
153 optimizer = OptimizerConfigs()
154 optimizer.parameters = c.model.parameters()
155 optimizer.optimizer = 'Adam'
156 optimizer.d_model = c.d_model
157
158 return optimizer
මෙමඅත්හදා බැලීමේදී අපි චරිත මට්ටමේ ටෝකනයිසර් භාවිතා කරමු. සැකසීමෙන් ඔබට මාරු විය හැකිය,
'tokenizer': 'basic_english',
අත්හදාබැලීම ආරම්භ කිරීමේදී වින්යාස කිරීමේ ශබ්දකෝෂයේ.
161@option(NLPClassificationConfigs.tokenizer)
162def basic_english():
176 from torchtext.data import get_tokenizer
177 return get_tokenizer('basic_english')
180def character_tokenizer(x: str):
184 return list(x)
අක්ෂරමට්ටමේ ටෝකනයිසර් වින්යාසය
187@option(NLPClassificationConfigs.tokenizer)
188def character():
192 return character_tokenizer
ටෝකනගණන ලබා ගන්න
195@option(NLPClassificationConfigs.n_tokens)
196def _n_tokens(c: NLPClassificationConfigs):
200 return len(c.vocab) + 2
203class CollateFunc:
tokenizer
ටෝකනයිසර් ශ්රිතය වේ vocab
යනු වචන මාලාවයි seq_len
අනුක්රමයේ දිග වේ padding_token
යනු පෙළ දිගට වඩා විශාල seq_len
වන විට පෑඩින් කිරීම සඳහා භාවිතා කරන ටෝකනය වේ classifier_token
යනු ආදානය අවසානයේ අප විසින් සකස් කරන ලද [CLS]
ටෝකනයයි208 def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
216 self.classifier_token = classifier_token
217 self.padding_token = padding_token
218 self.seq_len = seq_len
219 self.vocab = vocab
220 self.tokenizer = tokenizer
batch
විසින් එකතු කරන ලද දත්ත කාණ්ඩයයි DataLoader
222 def __call__(self, batch):
ආදානදත්ත ටෙන්සරය, සමඟ ආරම්භ කර ඇත padding_token
228 data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
හිස්ලේබල් ටෙන්සර්
230 labels = torch.zeros(len(batch), dtype=torch.long)
සාම්පලහරහා ලූප
233 for (i, (_label, _text)) in enumerate(batch):
ලේබලයසකසන්න
235 labels[i] = int(_label) - 1
ආදානපෙළ ටෝකනයිස් කරන්න
237 _text = [self.vocab[token] for token in self.tokenizer(_text)]
දක්වාටන්ක කරන්න seq_len
239 _text = _text[:self.seq_len]
දත්තසම්ප්රේෂණය කර එකතු කරන්න
241 data[:len(_text), i] = data.new_tensor(_text)
අනුපිළිවෙලෙහිඅවසාන ටෝකනය සකසන්න [CLS]
244 data[-1, :] = self.classifier_token
247 return data, labels
මෙයAG පුවත් දත්ත සමුදාය පටවන අතර n_classes
, සඳහා අගයන් සකසා ඇත, vocab
, train_loader
, සහ valid_loader
.
250@option([NLPClassificationConfigs.n_classes,
251 NLPClassificationConfigs.vocab,
252 NLPClassificationConfigs.train_loader,
253 NLPClassificationConfigs.valid_loader])
254def ag_news(c: NLPClassificationConfigs):
පුහුණුවසහ වලංගු කිරීමේ දත්ත කට්ටල ලබා ගන්න
263 train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))
මතකයවෙත දත්ත පූරණය කරන්න
266 with monit.section('Load data'):
267 from labml_nn.utils import MapStyleDataset
සිතියම් ආකාරයේ දත්ත කට්ටල සාදන්න
270 train, valid = MapStyleDataset(train), MapStyleDataset(valid)
ටෝකනයිසර්ලබා ගන්න
273 tokenizer = c.tokenizer
කවුන්ටරයක්සාදන්න
276 counter = Counter()
පුහුණුදත්ත කට්ටලයෙන් ටෝකන එකතු කරන්න
278 for (label, line) in train:
279 counter.update(tokenizer(line))
වලංගුදත්ත කට්ටලයෙන් ටෝකන එකතු කරන්න
281 for (label, line) in valid:
282 counter.update(tokenizer(line))
වචනමාලාව සාදන්න
284 vocab = torchtext.vocab.vocab(counter, min_freq=1)
පුහුණුදත්ත පැටවුම සාදන්න
287 train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
288 collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
වලංගුදත්ත පැටවුම සාදන්න
290 valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
291 collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
ආපසු n_classes
vocab
, train_loader
, සහ valid_loader
294 return 4, vocab, train_loader, valid_loader