Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-10-31 02:39:16 +08:00
shuffle data
@@ -6,11 +6,12 @@ from labml import lab, experiment, monit, tracker, logger
 from labml.configs import option
 from labml.logger import Text
 from labml.utils.pytorch import get_modules
-from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
+from labml_helpers.datasets.text import TextDataset, TextFileDataset, SequentialUnBatchedDataset
 from labml_helpers.metrics.accuracy import Accuracy
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
 from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
+from torch.utils.data import DataLoader
 
 from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
 
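The import change replaces labml_helpers' batched SequentialDataLoader with the per-sample SequentialUnBatchedDataset and pulls in PyTorch's standard torch.utils.data.DataLoader to do the batching instead. Delegating batching to DataLoader is what makes the shuffle=True option used below possible.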
@@ -48,6 +49,14 @@ class CrossEntropyLoss(Module):
         return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
 
 
+def transpose_batch(batch):
+    transposed_data = list(zip(*batch))
+    src = torch.stack(transposed_data[0], 1)
+    tgt = torch.stack(transposed_data[1], 1)
+
+    return src, tgt
+
+
 class Configs(SimpleTrainValidConfigs):
     """
     ## Configurations
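To see what the new collate function does, here is a small standalone check (a sketch: it assumes each dataset item is a (src, tgt) pair of 1-D tensors of length seq_len, which is what the loaders below produce). Stacking along dim=1 turns a list of samples into time-major [seq_len, batch_size] tensors, presumably the layout the LSTM training loop expects.

```python
import torch

def transpose_batch(batch):
    # `batch` is a list of (src, tgt) pairs; zip(*batch) regroups it into
    # a tuple of all srcs and a tuple of all tgts
    transposed_data = list(zip(*batch))
    # stack along dim=1 so each result is [seq_len, batch_size] (time-major)
    src = torch.stack(transposed_data[0], 1)
    tgt = torch.stack(transposed_data[1], 1)
    return src, tgt

# toy batch: 3 samples, seq_len 4 (values are arbitrary)
batch = [(torch.arange(4), torch.arange(1, 5)) for _ in range(3)]
src, tgt = transpose_batch(batch)
print(src.shape)  # torch.Size([4, 3])
```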
@@ -78,16 +87,20 @@ class Configs(SimpleTrainValidConfigs):
         self.optimizer = optimizer
 
         # Create a sequential data loader for training
-        self.train_loader = SequentialDataLoader(text=self.text.train,
-                                                 dataset=self.text,
-                                                 batch_size=self.batch_size,
-                                                 seq_len=self.seq_len)
+        self.train_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.train,
+                                                                  dataset=self.text,
+                                                                  seq_len=self.seq_len),
+                                       batch_size=self.batch_size,
+                                       collate_fn=transpose_batch,
+                                       shuffle=True)
         # Create a sequential data loader for validation
-        self.valid_loader = SequentialDataLoader(text=self.text.valid,
-                                                 dataset=self.text,
-                                                 batch_size=self.batch_size,
-                                                 seq_len=self.seq_len)
+        self.valid_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.valid,
+                                                                  dataset=self.text,
+                                                                  seq_len=self.seq_len),
+                                       batch_size=self.batch_size,
+                                       collate_fn=transpose_batch,
+                                       shuffle=True)
 
         self.state_modules = [self.accuracy]
 
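Shuffling becomes possible because SequentialUnBatchedDataset exposes the text as independent fixed-length windows instead of one pre-batched stream, so DataLoader can sample the windows in any order. Below is a minimal sketch of that idea, under the assumption that the dataset slices a 1-D token tensor into consecutive (src, tgt) windows shifted by one token; the class and argument names here are hypothetical stand-ins, and the real labml_helpers implementation (which also takes the `dataset` argument seen above) may differ in details.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SequentialWindowDataset(Dataset):
    """Hypothetical stand-in for SequentialUnBatchedDataset."""

    def __init__(self, text: torch.Tensor, seq_len: int):
        self.text = text          # 1-D tensor of token ids
        self.seq_len = seq_len

    def __len__(self):
        # number of full windows; tgt is src shifted by one token
        return (len(self.text) - 1) // self.seq_len

    def __getitem__(self, idx):
        start = idx * self.seq_len
        src = self.text[start: start + self.seq_len]
        tgt = self.text[start + 1: start + self.seq_len + 1]
        return src, tgt

def transpose_batch(batch):
    # same collate as in the diff: time-major [seq_len, batch_size]
    src, tgt = zip(*batch)
    return torch.stack(src, 1), torch.stack(tgt, 1)

tokens = torch.arange(1000)  # stand-in for the tokenized text
loader = DataLoader(SequentialWindowDataset(tokens, seq_len=32),
                    batch_size=4, collate_fn=transpose_batch, shuffle=True)
src, tgt = next(iter(loader))
print(src.shape)  # torch.Size([32, 4])
```

With shuffle=True, each epoch visits the windows in a fresh random order, which the old strictly sequential loader could not do.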
@@ -186,12 +199,12 @@ def main():
                        # A dictionary of configurations to override
                        {'tokenizer': 'character',
                         'text': 'tiny_shakespeare',
-                        'optimizer.learning_rate': 1e-4,
+                        'optimizer.learning_rate': 2.5e-4,
 
                         'seq_len': 512,
                         'epochs': 128,
                         'batch_size': 2,
-                        'inner_iterations': 10})
+                        'inner_iterations': 25})
 
     # This is needed to initialize models
     conf.n_tokens = conf.text.n_tokens
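The override changes accompany the loader change: the learning rate goes up from 1e-4 to 2.5e-4 and inner_iterations from 10 to 25 (in labml's train/valid helpers, inner_iterations controls how many times per epoch the loop alternates between training and validation). The commit message does not explain the choice; plausibly the shuffled, less correlated windows tolerate a larger step size than the old strictly sequential batches.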