11from collections import Counter
12from typing import Callable
13
14import torch
15import torchtext
16from torch import nn
17from torch.utils.data import DataLoader
18import torchtext.vocab
19from torchtext.vocab import Vocab
20
21from labml import lab, tracker, monit
22from labml.configs import option
23from labml_helpers.device import DeviceConfigs
24from labml_helpers.metrics.accuracy import Accuracy
25from labml_helpers.module import Module
26from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
27from labml_nn.optimizers.configs import OptimizerConfigs30class NLPClassificationConfigs(TrainValidConfigs):优化器
41    optimizer: torch.optim.Adam训练设备
43    device: torch.device = DeviceConfigs()自回归模型
46    model: Module批量大小
48    batch_size: int = 16序列的长度或上下文大小
50    seq_len: int = 512词汇
52    vocab: Vocab = 'ag_news'词汇中的代币数量
54    n_tokens: int班级数
56    n_classes: int = 'ag_news'分词器
58    tokenizer: Callable = 'character'是否定期保存模型
61    is_save_models = True亏损函数
64    loss_func = nn.CrossEntropyLoss()精度函数
66    accuracy = Accuracy()模型嵌入大小
68    d_model: int = 512渐变剪切
70    grad_norm_clip: float = 1.0训练数据加载器
73    train_loader: DataLoader = 'ag_news'验证数据加载器
75    valid_loader: DataLoader = 'ag_news'是否记录模型参数和梯度(每个纪元一次)。这些是每层的汇总统计数据,但它仍然可能导致非常深的网络的许多指标。
80    is_log_model_params_grads: bool = False是否记录模型激活(每个纪元一次)。这些是每层的汇总统计数据,但它仍然可能导致非常深的网络的许多指标。
85    is_log_model_activations: bool = False87    def init(self):设置跟踪器配置
92        tracker.set_scalar("accuracy.*", True)
93        tracker.set_scalar("loss.*", True)向日志模块输出添加钩子
95        hook_model_outputs(self.mode, self.model, 'model')增加作为状态模块的精度。这个名字可能令人困惑,因为它旨在存储 RNN 的训练和验证之间的状态。这将使精度指标统计数据分开,以便进行训练和验证。
100        self.state_modules = [self.accuracy]102    def step(self, batch: any, batch_idx: BatchIndex):将数据移动到设备
108        data, target = batch[0].to(self.device), batch[1].to(self.device)在训练模式下更新全局步长(处理的令牌数)
111        if self.mode.is_train:
112            tracker.add_global_step(data.shape[1])是否捕获模型输出
115        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):获取模型输出。它在使用 RNN 时返回状态的元组。这还没有实现。😜
119            output, *_ = self.model(data)计算并记录损失
122        loss = self.loss_func(output, target)
123        tracker.add("loss.", loss)计算和记录精度
126        self.accuracy(output, target)
127        self.accuracy.track()训练模型
130        if self.mode.is_train:计算梯度
132            loss.backward()剪辑渐变
134            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)采取优化器步骤
136            self.optimizer.step()记录每个纪元最后一批的模型参数和梯度
138            if batch_idx.is_last and self.is_log_model_params_grads:
139                tracker.add('model', self.model)清除渐变
141            self.optimizer.zero_grad()保存跟踪的指标
144        tracker.save()147@option(NLPClassificationConfigs.optimizer)
148def _optimizer(c: NLPClassificationConfigs):153    optimizer = OptimizerConfigs()
154    optimizer.parameters = c.model.parameters()
155    optimizer.optimizer = 'Adam'
156    optimizer.d_model = c.d_model
157
158    return optimizer161@option(NLPClassificationConfigs.tokenizer)
162def basic_english():176    from torchtext.data import get_tokenizer
177    return get_tokenizer('basic_english')180def character_tokenizer(x: str):184    return list(x)角色级别分词器配置
187@option(NLPClassificationConfigs.tokenizer)
188def character():192    return character_tokenizer获取代币数量
195@option(NLPClassificationConfigs.n_tokens)
196def _n_tokens(c: NLPClassificationConfigs):200    return len(c.vocab) + 2203class CollateFunc:tokenizer
是分词器函数vocab
是词汇seq_len
是序列的长度padding_token
是大于文本长度时seq_len
用于填充的标记classifier_token
是我们在输入末尾设置的[CLS]
令牌208    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):216        self.classifier_token = classifier_token
217        self.padding_token = padding_token
218        self.seq_len = seq_len
219        self.vocab = vocab
220        self.tokenizer = tokenizerbatch
是由DataLoader
222    def __call__(self, batch):输入数据张量,初始化为padding_token
228        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)空标签张量
230        labels = torch.zeros(len(batch), dtype=torch.long)循环浏览样本
233        for (i, (_label, _text)) in enumerate(batch):设置标签
235            labels[i] = int(_label) - 1标记输入文本
237            _text = [self.vocab[token] for token in self.tokenizer(_text)]截断最多seq_len
239            _text = _text[:self.seq_len]转置并添加到数据
241            data[:len(_text), i] = data.new_tensor(_text)将序列中的最后一个令牌设置为[CLS]
244        data[-1, :] = self.classifier_token247        return data, labels250@option([NLPClassificationConfigs.n_classes,
251         NLPClassificationConfigs.vocab,
252         NLPClassificationConfigs.train_loader,
253         NLPClassificationConfigs.valid_loader])
254def ag_news(c: NLPClassificationConfigs):获取训练和验证数据集
263    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))将数据加载到内存
266    with monit.section('Load data'):
267        from labml_nn.utils import MapStyleDataset获取分词器
273    tokenizer = c.tokenizer创建计数器
276    counter = Counter()从训练数据集中收集令牌
278    for (label, line) in train:
279        counter.update(tokenizer(line))从验证数据集中收集令牌
281    for (label, line) in valid:
282        counter.update(tokenizer(line))创建词汇
284    vocab = torchtext.vocab.vocab(counter, min_freq=1)创建训练数据加载器
287    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
288                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))创建验证数据加载器
290    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
291                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))返回n_classes
vocab
、train_loader
、和valid_loader
294    return 4, vocab, train_loader, valid_loader