mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 10:48:49 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			63 lines
		
	
	
		
			1.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			63 lines
		
	
	
		
			1.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| ---
 | |
| title: "Parity Task"
 | |
| summary: >
 | |
|   This creates data for Parity Task from the paper Adaptive Computation Time
 | |
|   for Recurrent Neural Networks
 | |
| ---
 | |
| 
 | |
| # Parity Task
 | |
| 
 | |
| This creates data for Parity Task from the paper
 | |
| [Adaptive Computation Time for Recurrent Neural Networks](https://papers.labml.ai/paper/1603.08983).
 | |
| 
 | |
| The input of the parity task is a vector of $0$'s, $1$'s and $-1$'s.
 | |
| The output is the parity of $1$'s - one if there is an odd number of $1$'s and zero otherwise.
 | |
| The input is generated by setting a random number of elements in the vector to either $1$ or $-1$.
 | |
| """
 | |
| 
 | |
| from typing import Tuple
 | |
| 
 | |
| import torch
 | |
| from torch.utils.data import Dataset
 | |
| 
 | |
| 
 | |
class ParityDataset(Dataset):
    """
    ### Parity dataset

    A map-style dataset whose samples are generated randomly on every access;
    the index only determines whether the access is within bounds.
    """

    def __init__(self, n_samples: int, n_elems: int = 64):
        """
        * `n_samples` is the number of samples
        * `n_elems` is the number of elements in the input vector

        Raises `ValueError` if `n_elems` is less than $1$, since every sample
        must contain at least one non-zero element.
        """
        # Fail fast here instead of with an opaque `torch.randint` error at sampling time
        if n_elems < 1:
            raise ValueError(f"n_elems must be at least 1, got {n_elems}")

        self.n_samples = n_samples
        self.n_elems = n_elems

    def __len__(self):
        """
        Size of the dataset
        """
        return self.n_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate a sample

        * `idx` is the sample index; it is only used for bounds checking,
          the sample itself is drawn randomly

        Returns the input vector `x` of `n_elems` elements and its parity `y`.
        Raises `IndexError` if `idx` is outside `[-n_samples, n_samples)`.
        """

        # Without this check plain iteration (`for sample in dataset`) would never
        # terminate, because the sequence protocol stops only on `IndexError`
        if not -self.n_samples <= idx < self.n_samples:
            raise IndexError(f"index {idx} is out of range for dataset of size {self.n_samples}")

        # Empty vector
        x = torch.zeros((self.n_elems,))
        # Number of non-zero elements - a random number between $1$ and total number of elements
        n_non_zero = torch.randint(1, self.n_elems + 1, (1,)).item()
        # Fill non-zero elements with $1$'s and $-1$'s
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
        # Randomly permute the elements
        x = x[torch.randperm(self.n_elems)]

        # The parity - one if the number of $1$'s is odd and zero otherwise
        y = (x == 1.).sum() % 2

        #
        return x, y
 | 
