Files
Varuna Jayasiri 9bef456004 PonderNet (#76)
2021-08-12 15:45:01 +05:30

63 lines
1.7 KiB
Python

"""
---
title: "Parity Task"
summary: >
This creates data for Parity Task from the paper Adaptive Computation Time
for Recurrent Neural Networks
---
# Parity Task
This creates data for Parity Task from the paper
[Adaptive Computation Time for Recurrent Neural Networks](https://papers.labml.ai/paper/1603.08983).
The input of the parity task is a vector of $0$'s, $1$'s and $-1$'s.
The output is the parity of the number of $1$'s: one if there is an odd number of $1$'s and zero otherwise.
The input is generated by setting a random number of elements of the vector to either $1$ or $-1$ (each chosen at random) and leaving the remaining elements as $0$.
"""
from typing import Tuple
import torch
from torch.utils.data import Dataset
class ParityDataset(Dataset):
    """
    ### Parity dataset

    Each sample is generated on the fly. The input vector `x` has `n_elems`
    entries. A random number of them (between $1$ and `n_elems`), at random
    positions, are set to $1$ or $-1$, and the rest are $0$. The target `y` is
    the parity of the number of $1$'s in `x`.

    Samples do not depend on the index, so reading the same index twice
    generally returns different samples.
    """

    def __init__(self, n_samples: int, n_elems: int = 64):
        """
        * `n_samples` is the number of samples (must be non-negative)
        * `n_elems` is the number of elements in the input vector (must be at least $1$)
        """
        # Fail early with a clear message.
        # `n_elems == 0` would otherwise only fail later, inside `torch.randint`.
        if n_samples < 0:
            raise ValueError(f"n_samples must be non-negative, got {n_samples}")
        if n_elems < 1:
            raise ValueError(f"n_elems must be at least 1, got {n_elems}")
        self.n_samples = n_samples
        self.n_elems = n_elems

    def __len__(self):
        """
        Size of the dataset
        """
        return self.n_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate a sample

        `idx` is only range-checked (negative indices are accepted); the sample
        itself is random. Raises `IndexError` for out-of-range indices, so plain
        iteration (`for x, y in dataset`, which falls back to `__getitem__`)
        stops after `n_samples` items instead of running forever.

        Returns `(x, y)`:
        * `x` is a floating-point vector of length `n_elems`
        * `y` is a 0-dim integer tensor holding the parity
        """
        if not -self.n_samples <= idx < self.n_samples:
            raise IndexError(f"index {idx} out of range for dataset of size {self.n_samples}")
        # Empty vector
        x = torch.zeros((self.n_elems,))
        # Number of non-zero elements - a random number between $1$ and total number of elements
        n_non_zero = torch.randint(1, self.n_elems + 1, (1,)).item()
        # Fill non-zero elements with $1$'s and $-1$'s
        # (`randint(0, 2) * 2 - 1` maps {0, 1} to {-1, 1})
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
        # Randomly permute the elements so the non-zeros are not always at the front
        x = x[torch.randperm(self.n_elems)]
        # The parity of the number of $1$'s
        y = (x == 1.).sum() % 2
        # Return the input and the target
        return x, y