"""
---
title: "Parity Task"
summary: >
  This creates data for Parity Task from the paper Adaptive Computation Time for Recurrent Neural Networks
---

# Parity Task

This creates data for Parity Task from the paper
[Adaptive Computation Time for Recurrent Neural Networks](https://papers.labml.ai/paper/1603.08983).

The input of the parity task is a vector with $0$'s, $1$'s and $-1$'s.
The output is the parity of $1$'s - one if there is an odd number of $1$'s and zero otherwise.
The input is generated by making a random number of elements in the vector either $1$ or $-1$.
"""

from typing import Tuple

import torch
from torch.utils.data import Dataset


class ParityDataset(Dataset):
    """
    ### Parity dataset

    A map-style dataset of `n_samples` randomly generated parity problems.
    A fresh random sample is drawn on every access: `idx` only has to lie
    within the dataset's length, it does not determine the sample's content.
    """

    def __init__(self, n_samples: int, n_elems: int = 64):
        """
        * `n_samples` is the number of samples (must be non-negative)
        * `n_elems` is the number of elements in the input vector (must be at least 1)

        Raises `ValueError` for out-of-range arguments.
        """
        # Validate eagerly: `n_elems < 1` would otherwise only surface later as an
        # opaque "empty range" error from `torch.randint` in `__getitem__`, and a
        # negative `n_samples` would make `__len__` return an invalid length.
        if n_samples < 0:
            raise ValueError(f"n_samples must be non-negative, got {n_samples}")
        if n_elems < 1:
            raise ValueError(f"n_elems must be at least 1, got {n_elems}")
        self.n_samples = n_samples
        self.n_elems = n_elems

    def __len__(self):
        """
        Size of the dataset
        """
        return self.n_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate a sample

        Returns `(x, y)`: `x` is a float vector of length `n_elems` whose entries
        are $0$, $1$ or $-1$, and `y` is a 0-dim integer tensor equal to the
        number of $1$'s in `x` modulo $2$.

        Raises `IndexError` when `idx` is outside `[-n_samples, n_samples)`.
        """
        # `Dataset` has no `__iter__`, so plain iteration (`for s in ds`, `list(ds)`)
        # falls back to calling `__getitem__(0), (1), ...` until `IndexError`;
        # without this bound check such iteration would never terminate.
        if not -self.n_samples <= idx < self.n_samples:
            raise IndexError(f"index {idx} is out of range for dataset of size {self.n_samples}")

        # Empty vector
        x = torch.zeros((self.n_elems,))
        # Number of non-zero elements - a random number between $1$ and total number of elements
        n_non_zero = torch.randint(1, self.n_elems + 1, (1,)).item()
        # Fill non-zero elements with $1$'s and $-1$'s
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
        # Randomly permute the elements
        x = x[torch.randperm(self.n_elems)]

        # The parity
        y = (x == 1.).sum() % 2

        return x, y