Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
(synced 2025-10-31 02:39:16 +08:00)

Commit: shuffle data
@@ -6,11 +6,12 @@ from labml import lab, experiment, monit, tracker, logger
 from labml.configs import option
 from labml.logger import Text
 from labml.utils.pytorch import get_modules
-from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
+from labml_helpers.datasets.text import TextDataset, TextFileDataset, SequentialUnBatchedDataset
 from labml_helpers.metrics.accuracy import Accuracy
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
 from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
+from torch.utils.data import DataLoader

 from labml_nn.hypernetworks.hyper_lstm import HyperLSTM

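Note: this commit swaps the library's SequentialDataLoader for PyTorch's standard DataLoader wrapped around a SequentialUnBatchedDataset, so that individual sequences (rather than one fixed sequential batching) can be shuffled. As a rough sketch only, not the actual labml_helpers implementation, an unbatched sequential text dataset could look like this:

import torch
from torch.utils.data import Dataset

class UnBatchedTextDataset(Dataset):
    # Hypothetical sketch: slice one long tensor of token ids into
    # (input, target) sequence pairs that a DataLoader can shuffle.
    def __init__(self, data: torch.Tensor, seq_len: int):
        self.data = data          # 1-D tensor of token ids
        self.seq_len = seq_len

    def __len__(self):
        # Each item consumes seq_len tokens, plus one for the shifted target
        return (len(self.data) - 1) // self.seq_len

    def __getitem__(self, idx):
        start = idx * self.seq_len
        src = self.data[start:start + self.seq_len]          # inputs
        tgt = self.data[start + 1:start + self.seq_len + 1]  # next-token targets
        return src, tgt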
@@ -48,6 +49,14 @@ class CrossEntropyLoss(Module):
         return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))


+def transpose_batch(batch):
+    transposed_data = list(zip(*batch))
+    src = torch.stack(transposed_data[0], 1)
+    tgt = torch.stack(transposed_data[1], 1)
+
+    return src, tgt
+
+
 class Configs(SimpleTrainValidConfigs):
     """
     ## Configurations
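transpose_batch serves as a collate_fn: the DataLoader hands it a list of (src, tgt) pairs, and stacking along dimension 1 yields tensors of shape [seq_len, batch_size], the time-major layout the rest of the training loop expects. A quick check of that behavior:

import torch

# Two (src, tgt) samples of sequence length 4, as a DataLoader would collect them
batch = [(torch.arange(4), torch.arange(1, 5)),
         (torch.arange(10, 14), torch.arange(11, 15))]

src, tgt = transpose_batch(batch)
print(src.shape)   # torch.Size([4, 2]) -> [seq_len, batch_size]
print(src[:, 0])   # tensor([0, 1, 2, 3]); the first sample runs down column 0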
@@ -78,16 +87,20 @@ class Configs(SimpleTrainValidConfigs):
         self.optimizer = optimizer

         # Create a sequential data loader for training
-        self.train_loader = SequentialDataLoader(text=self.text.train,
-                                                 dataset=self.text,
-                                                 batch_size=self.batch_size,
-                                                 seq_len=self.seq_len)
+        self.train_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.train,
+                                                                  dataset=self.text,
+                                                                  seq_len=self.seq_len),
+                                       batch_size=self.batch_size,
+                                       collate_fn=transpose_batch,
+                                       shuffle=True)

         # Create a sequential data loader for validation
-        self.valid_loader = SequentialDataLoader(text=self.text.valid,
-                                                 dataset=self.text,
-                                                 batch_size=self.batch_size,
-                                                 seq_len=self.seq_len)
+        self.valid_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.valid,
+                                                                  dataset=self.text,
+                                                                  seq_len=self.seq_len),
+                                       batch_size=self.batch_size,
+                                       collate_fn=transpose_batch,
+                                       shuffle=True)

         self.state_modules = [self.accuracy]

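Putting it together: with shuffle=True the stock DataLoader draws sequence pairs in random order each epoch, and transpose_batch restores the batch-second layout the old SequentialDataLoader produced. A minimal end-to-end sketch, reusing transpose_batch from the diff and the hypothetical UnBatchedTextDataset above:

import torch
from torch.utils.data import DataLoader

data = torch.randint(0, 65, (10_000,))             # toy corpus of token ids
dataset = UnBatchedTextDataset(data, seq_len=512)

train_loader = DataLoader(dataset,
                          batch_size=2,
                          collate_fn=transpose_batch,
                          shuffle=True)             # the point of this commit

src, tgt = next(iter(train_loader))
assert src.shape == (512, 2)                        # [seq_len, batch_size]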
@@ -186,12 +199,12 @@ def main():
                        # A dictionary of configurations to override
                        {'tokenizer': 'character',
                         'text': 'tiny_shakespeare',
-                        'optimizer.learning_rate': 1e-4,
+                        'optimizer.learning_rate': 2.5e-4,

                         'seq_len': 512,
                         'epochs': 128,
                         'batch_size': 2,
-                        'inner_iterations': 10})
+                        'inner_iterations': 25})

     # This is needed to initialize models
     conf.n_tokens = conf.text.n_tokens

Author: Varuna Jayasiri