mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git

Commit: shuffle data
Author: Varuna Jayasiri
@@ -6,11 +6,12 @@ from labml import lab, experiment, monit, tracker, logger
 from labml.configs import option
 from labml.logger import Text
 from labml.utils.pytorch import get_modules
-from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
+from labml_helpers.datasets.text import TextDataset, TextFileDataset, SequentialUnBatchedDataset
 from labml_helpers.metrics.accuracy import Accuracy
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
 from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
+from torch.utils.data import DataLoader
 
 from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
 
@@ -48,6 +49,14 @@ class CrossEntropyLoss(Module):
         return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
 
 
+def transpose_batch(batch):
+    transposed_data = list(zip(*batch))
+    src = torch.stack(transposed_data[0], 1)
+    tgt = torch.stack(transposed_data[1], 1)
+
+    return src, tgt
+
+
 class Configs(SimpleTrainValidConfigs):
     """
     ## Configurations
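For context on the new transpose_batch collate function: a PyTorch DataLoader stacks samples along a new first dimension by default (batch-first), while the LSTM-based models here take time-major [seq_len, batch_size] tensors, so the collate function re-stacks along dim=1 instead. A minimal sketch with toy values (the toy batch is illustrative only; the hunk adds no "import torch", so the real module presumably imports it already):

import torch

def transpose_batch(batch):
    # batch is a list of (src, tgt) pairs, each a tensor of shape [seq_len]
    transposed_data = list(zip(*batch))
    # Stack along dim=1 so the result is [seq_len, batch_size] (time-major)
    src = torch.stack(transposed_data[0], 1)
    tgt = torch.stack(transposed_data[1], 1)
    return src, tgt

# Toy batch: two samples of length 4
batch = [(torch.arange(4), torch.arange(1, 5)),
         (torch.arange(10, 14), torch.arange(11, 15))]
src, tgt = transpose_batch(batch)
print(src.shape)  # torch.Size([4, 2])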
@@ -78,16 +87,20 @@ class Configs(SimpleTrainValidConfigs):
         self.optimizer = optimizer
 
         # Create a sequential data loader for training
-        self.train_loader = SequentialDataLoader(text=self.text.train,
-                                                 dataset=self.text,
-                                                 batch_size=self.batch_size,
-                                                 seq_len=self.seq_len)
+        self.train_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.train,
+                                                                  dataset=self.text,
+                                                                  seq_len=self.seq_len),
+                                       batch_size=self.batch_size,
+                                       collate_fn=transpose_batch,
+                                       shuffle=True)
 
         # Create a sequential data loader for validation
-        self.valid_loader = SequentialDataLoader(text=self.text.valid,
-                                                 dataset=self.text,
-                                                 batch_size=self.batch_size,
-                                                 seq_len=self.seq_len)
+        self.valid_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.valid,
+                                                                  dataset=self.text,
+                                                                  seq_len=self.seq_len),
+                                       batch_size=self.batch_size,
+                                       collate_fn=transpose_batch,
+                                       shuffle=True)
 
         self.state_modules = [self.accuracy]
 
@@ -186,12 +199,12 @@ def main():
                        # A dictionary of configurations to override
                        {'tokenizer': 'character',
                         'text': 'tiny_shakespeare',
-                        'optimizer.learning_rate': 1e-4,
+                        'optimizer.learning_rate': 2.5e-4,
 
                         'seq_len': 512,
                         'epochs': 128,
                         'batch_size': 2,
-                        'inner_iterations': 10})
+                        'inner_iterations': 25})
 
     # This is needed to initialize models
     conf.n_tokens = conf.text.n_tokens
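Why the loaders changed: the old SequentialDataLoader yielded batches in a fixed order, with no shuffling between epochs. SequentialUnBatchedDataset instead exposes each fixed-length sequence as an individual item, which lets a standard torch.utils.data.DataLoader handle batching and randomize the order of sequences via shuffle=True. A rough sketch of the idea, where UnbatchedTextDataset is a hypothetical stand-in for labml_helpers' SequentialUnBatchedDataset (the real class also takes the dataset argument seen in the diff):

import torch
from torch.utils.data import Dataset, DataLoader

class UnbatchedTextDataset(Dataset):
    # Hypothetical stand-in: each item is one (src, tgt) sequence, so the
    # DataLoader can shuffle sequences freely.
    def __init__(self, text: torch.Tensor, seq_len: int):
        self.text = text
        self.seq_len = seq_len

    def __len__(self):
        # One extra token is needed for the shifted target
        return (len(self.text) - 1) // self.seq_len

    def __getitem__(self, idx):
        i = idx * self.seq_len
        src = self.text[i: i + self.seq_len]          # input tokens
        tgt = self.text[i + 1: i + self.seq_len + 1]  # next-token targets
        return src, tgt

data = torch.arange(1000)  # toy token ids
loader = DataLoader(UnbatchedTextDataset(data, seq_len=32),
                    batch_size=2,
                    collate_fn=transpose_batch,  # from the sketch above
                    shuffle=True)
src, tgt = next(iter(loader))
print(src.shape)  # torch.Size([32, 2])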