වර්ගීකරණයසඳහා NLP ආකෘති පුහුණුකරු

11from collections import Counter
12from typing import Callable
13
14import torch
15import torchtext
16from torch import nn
17from torch.utils.data import DataLoader
18import torchtext.vocab
19from torchtext.vocab import Vocab
20
21from labml import lab, tracker, monit
22from labml.configs import option
23from labml_helpers.device import DeviceConfigs
24from labml_helpers.metrics.accuracy import Accuracy
25from labml_helpers.module import Module
26from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
27from labml_nn.optimizers.configs import OptimizerConfigs

පුහුණුකරුමානකරණ

එන්එල්පීවර්ගීකරණ කාර්ය පුහුණුව සඳහා මූලික වින්යාසයන් මෙයට ඇත. සියලුම ගුණාංග වින්යාසගත කළ හැකිය.

30class NLPClassificationConfigs(TrainValidConfigs):

ප්රශස්තකරණය

41    optimizer: torch.optim.Adam

පුහුණුඋපාංගය

43    device: torch.device = DeviceConfigs()

ස්වයංක්රීයප්රතිගාමී ආකෘතිය

46    model: Module

කණ්ඩායම්ප්රමාණය

48    batch_size: int = 16

අනුක්රමයේදිග, හෝ සන්දර්භය ප්රමාණය

50    seq_len: int = 512

වචනමාලාව

52    vocab: Vocab = 'ag_news'

වචනමාලාවේ ටෝකන් ගණන

54    n_tokens: int

පන්තිගණන

56    n_classes: int = 'ag_news'

ටෝකනයිසර්

58    tokenizer: Callable = 'character'

වරින්වර ආකෘති සුරැකීමට යන්න

61    is_save_models = True

පාඩුශ්රිතය

64    loss_func = nn.CrossEntropyLoss()

නිරවද්යතාශ්රිතය

66    accuracy = Accuracy()

ආදර්ශකාවැද්දීමේ ප්රමාණය

68    d_model: int = 512

ශ්රේණියේක්ලිපින්

70    grad_norm_clip: float = 1.0

පුහුණුදත්ත පැටවුම

73    train_loader: DataLoader = 'ag_news'

වලංගුදත්ත පැටවුම

75    valid_loader: DataLoader = 'ag_news'

ආදර්ශපරාමිතීන් සහ අනුක්රමික ලඝු-සටහන යන්න (එක් එක් එක් එක් එක් වරක්). මේවා ස්ථරයකට සාරාංශගත සංඛ්යාන වේ, නමුත් එය තවමත් ඉතා ගැඹුරු ජාල සඳහා බොහෝ දර්ශක වලට හේතු විය හැකිය.

80    is_log_model_params_grads: bool = False

ආදර්ශසක්රිය කිරීම් ලොග් කළ යුතුද යන්න (එක් එක් ඊපෝච් එකකට වරක්). මේවා ස්ථරයකට සාරාංශගත සංඛ්යාන වේ, නමුත් එය තවමත් ඉතා ගැඹුරු ජාල සඳහා බොහෝ දර්ශක වලට හේතු විය හැකිය.

85    is_log_model_activations: bool = False

ආරම්භකකරණය

87    def init(self):

ට්රැකර්වින්යාසයන් සකසන්න

92        tracker.set_scalar("accuracy.*", True)
93        tracker.set_scalar("loss.*", True)

මොඩියුලප්රතිදානයන් ලොග් කිරීමට කොක්කක් එක් කරන්න

95        hook_model_outputs(self.mode, self.model, 'model')

රාජ්යමොඩියුලයක් ලෙස නිරවද්යතාව එක් කරන්න. RNs සඳහා පුහුණුව සහ වලංගු කිරීම අතර රාජ්යයන් ගබඩා කිරීම අදහස් කරන බැවින් නම බොහෝ විට ව්යාකූල වේ. මෙය පුහුණුව සහ වලංගු කිරීම සඳහා නිරවද්යතා මෙට්රික් සංඛ්යාන වෙනම තබා ගනී.

100        self.state_modules = [self.accuracy]

පුහුණුවහෝ වලංගු කිරීමේ පියවර

102    def step(self, batch: any, batch_idx: BatchIndex):

උපාංගයවෙත දත්ත ගෙනයන්න

108        data, target = batch[0].to(self.device), batch[1].to(self.device)

පුහුණුප්රකාරයේදී ගෝලීය පියවර යාවත්කාලීන කරන්න (සැකසූ ටෝකන ගණන)

111        if self.mode.is_train:
112            tracker.add_global_step(data.shape[1])

ආකෘතිප්රතිදානයන් ග්රහණය කර ගත යුතුද යන්න

115        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):

ආදර්ශප්රතිදානයන් ලබා ගන්න. ආර්එන්එස් භාවිතා කරන විට එය ප්රාන්ත සඳහා ටූල් එකක් නැවත ලබා දෙයි. මෙය තවම ක්රියාත්මක කර නැත 😜

119            output, *_ = self.model(data)

ගණනයකිරීම සහ ලොග් වීම

122        loss = self.loss_func(output, target)
123        tracker.add("loss.", loss)

ගණනයකිරීම සහ ලොග් කිරීමේ නිරවද්යතාවය

126        self.accuracy(output, target)
127        self.accuracy.track()

ආකෘතියපුහුණු කරන්න

130        if self.mode.is_train:

අනුක්රමිකගණනය කරන්න

132            loss.backward()

ක්ලිප්අනුක්රමික

134            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

ප්රශස්තිකරණපියවර ගන්න

136            self.optimizer.step()

සෑමයුගලයකම අවසාන කණ්ඩායමේ ආදර්ශ පරාමිතීන් සහ අනුක්රමික ලොග් කරන්න

138            if batch_idx.is_last and self.is_log_model_params_grads:
139                tracker.add('model', self.model)

අනුක්රමිකඉවත්

141            self.optimizer.zero_grad()

ලුහුබැඳඇති ප්රමිතික සුරකින්න

144        tracker.save()
147@option(NLPClassificationConfigs.optimizer)
148def _optimizer(c: NLPClassificationConfigs):
153    optimizer = OptimizerConfigs()
154    optimizer.parameters = c.model.parameters()
155    optimizer.optimizer = 'Adam'
156    optimizer.d_model = c.d_model
157
158    return optimizer

මූලිකඉංග්රීසි ටෝකනයිසර්

මෙමඅත්හදා බැලීමේදී අපි චරිත මට්ටමේ ටෝකනයිසර් භාවිතා කරමු. සැකසීමෙන් ඔබට මාරු විය හැකිය,

'tokenizer': 'basic_english',

අත්හදාබැලීම ආරම්භ කිරීමේදී වින්යාස කිරීමේ ශබ්දකෝෂයේ.

161@option(NLPClassificationConfigs.tokenizer)
162def basic_english():
176    from torchtext.data import get_tokenizer
177    return get_tokenizer('basic_english')

අක්ෂරමට්ටමේ ටෝකනයිසර්

180def character_tokenizer(x: str):
184    return list(x)

අක්ෂරමට්ටමේ ටෝකනයිසර් වින්යාසය

187@option(NLPClassificationConfigs.tokenizer)
188def character():
192    return character_tokenizer

ටෝකනගණන ලබා ගන්න

195@option(NLPClassificationConfigs.n_tokens)
196def _n_tokens(c: NLPClassificationConfigs):
200    return len(c.vocab) + 2

දත්තකාණ්ඩවලට පැටවීමේ කාර්යය

203class CollateFunc:
  • tokenizer ටෝකනයිසර් ශ්රිතය වේ
  • vocab යනු වචන මාලාවයි
  • seq_len අනුක්රමයේ දිග වේ
  • padding_token යනු පෙළ දිගට වඩා විශාල seq_len වන විට පෑඩින් කිරීම සඳහා භාවිතා කරන ටෝකනය වේ
  • classifier_token යනු ආදානය අවසානයේ අප විසින් සකස් කරන ලද [CLS] ටෝකනයයි
208    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
216        self.classifier_token = classifier_token
217        self.padding_token = padding_token
218        self.seq_len = seq_len
219        self.vocab = vocab
220        self.tokenizer = tokenizer
  • batch විසින් එකතු කරන ලද දත්ත කාණ්ඩයයි DataLoader
222    def __call__(self, batch):

ආදානදත්ත ටෙන්සරය, සමඟ ආරම්භ කර ඇත padding_token

228        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)

හිස්ලේබල් ටෙන්සර්

230        labels = torch.zeros(len(batch), dtype=torch.long)

සාම්පලහරහා ලූප

233        for (i, (_label, _text)) in enumerate(batch):

ලේබලයසකසන්න

235            labels[i] = int(_label) - 1

ආදානපෙළ ටෝකනයිස් කරන්න

237            _text = [self.vocab[token] for token in self.tokenizer(_text)]

දක්වාටන්ක කරන්න seq_len

239            _text = _text[:self.seq_len]

දත්තසම්ප්රේෂණය කර එකතු කරන්න

241            data[:len(_text), i] = data.new_tensor(_text)

අනුපිළිවෙලෙහිඅවසාන ටෝකනය සකසන්න [CLS]

244        data[-1, :] = self.classifier_token

247        return data, labels

AGපුවත් දත්ත කට්ටලය

මෙයAG පුවත් දත්ත සමුදාය පටවන අතර n_classes , සඳහා අගයන් සකසා ඇත, vocab , train_loader , සහ valid_loader .

250@option([NLPClassificationConfigs.n_classes,
251         NLPClassificationConfigs.vocab,
252         NLPClassificationConfigs.train_loader,
253         NLPClassificationConfigs.valid_loader])
254def ag_news(c: NLPClassificationConfigs):

පුහුණුවසහ වලංගු කිරීමේ දත්ත කට්ටල ලබා ගන්න

263    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))

මතකයවෙත දත්ත පූරණය කරන්න

266    with monit.section('Load data'):
267        from labml_nn.utils import MapStyleDataset
270        train, valid = MapStyleDataset(train), MapStyleDataset(valid)

ටෝකනයිසර්ලබා ගන්න

273    tokenizer = c.tokenizer

කවුන්ටරයක්සාදන්න

276    counter = Counter()

පුහුණුදත්ත කට්ටලයෙන් ටෝකන එකතු කරන්න

278    for (label, line) in train:
279        counter.update(tokenizer(line))

වලංගුදත්ත කට්ටලයෙන් ටෝකන එකතු කරන්න

281    for (label, line) in valid:
282        counter.update(tokenizer(line))

වචනමාලාව සාදන්න

284    vocab = torchtext.vocab.vocab(counter, min_freq=1)

පුහුණුදත්ත පැටවුම සාදන්න

287    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
288                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

වලංගුදත්ත පැටවුම සාදන්න

290    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
291                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

ආපසු n_classes vocab , train_loader , සහ valid_loader

294    return 4, vocab, train_loader, valid_loader