mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-11-04 14:29:43 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			63 lines
		
	
	
		
			1.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			63 lines
		
	
	
		
			1.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""
---
title: "Parity Task"
summary: >
  This creates data for Parity Task from the paper Adaptive Computation Time
  for Recurrent Neural Networks
---

# Parity Task

This creates data for the Parity Task from the paper
[Adaptive Computation Time for Recurrent Neural Networks](https://papers.labml.ai/paper/1603.08983).

The input of the parity task is a vector with $0$'s, $1$'s and $-1$'s.
The output is the parity of $1$'s - one if there is an odd number of $1$'s and zero otherwise.
The input is generated by making a random number of elements in the vector either $1$ or $-1$.
"""
 | 
						|
 | 
						|
from typing import Tuple
 | 
						|
 | 
						|
import torch
 | 
						|
from torch.utils.data import Dataset
 | 
						|
 | 
						|
 | 
						|
class ParityDataset(Dataset):
    """
    ### Parity dataset

    Samples are not stored; every call to `__getitem__` draws a fresh random
    input vector and computes its parity label.
    """

    def __init__(self, n_samples: int, n_elems: int = 64):
        """
        * `n_samples` is the number of samples
        * `n_elems` is the number of elements in the input vector
        """
        self.n_samples = n_samples
        self.n_elems = n_elems

    def __len__(self) -> int:
        """
        Size of the dataset
        """
        return self.n_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate a sample

        * `idx` is the sample index. It is only range-checked; the sample itself
          is random and does not depend on `idx`.

        Returns a tuple `(x, y)` where `x` is a vector of `n_elems` values drawn
        from $0$, $1$ and $-1$, and `y` is a scalar tensor that is one if `x`
        contains an odd number of $1$'s and zero otherwise.

        Raises `IndexError` if `idx` is outside `[-n_samples, n_samples)`.
        """

        # Reject out-of-range indexes. Without this, iterating the dataset directly
        # (`for x, y in dataset`) never stops, because Python's sequence iteration
        # protocol relies on `IndexError` to detect the end.
        if not -self.n_samples <= idx < self.n_samples:
            raise IndexError(f"index {idx} is out of range for dataset of size {self.n_samples}")

        # Empty vector
        x = torch.zeros((self.n_elems,))
        # Number of non-zero elements - a random number between $1$ and total number of elements
        # (the upper bound of `torch.randint` is exclusive, hence the `+ 1`)
        n_non_zero = torch.randint(1, self.n_elems + 1, (1,)).item()
        # Fill non-zero elements with $1$'s and $-1$'s:
        # random values in $\{0, 1\}$ are mapped to $\{-1, 1\}$
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
        # Randomly permute the elements so the non-zero entries are not all at the front
        x = x[torch.randperm(self.n_elems)]

        # The parity - number of $1$'s modulo $2$ ($-1$'s are not counted)
        y = (x == 1.).sum() % 2

        #
        return x, y
 |