mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-15 02:07:56 +08:00
63 lines
1.7 KiB
Python
63 lines
1.7 KiB
Python
"""
|
|
---
|
|
title: "Parity Task"
|
|
summary: >
|
|
This creates data for the Parity Task from the paper Adaptive Computation Time
|
|
for Recurrent Neural Networks
|
|
---
|
|
|
|
# Parity Task
|
|
|
|
This creates data for the Parity Task from the paper
|
|
[Adaptive Computation Time for Recurrent Neural Networks](https://papers.labml.ai/paper/1603.08983).
|
|
|
|
The input of the parity task is a vector of $0$'s, $1$'s and $-1$'s.
|
|
The output is the parity of $1$'s - one if there is an odd number of $1$'s and zero otherwise.
|
|
The input is generated by setting a random number of the vector's elements to either $1$ or $-1$; the remaining elements are $0$.
|
|
"""
|
|
|
|
from typing import Tuple
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
class ParityDataset(Dataset):
    """
    ### Parity dataset

    A map-style dataset of randomly generated parity-task samples.
    Samples are drawn fresh from PyTorch's global random number generator on
    every access; the index is only bounds-checked, so indexing the same
    position twice generally returns two different samples.
    """

    def __init__(self, n_samples: int, n_elems: int = 64):
        """
        * `n_samples` is the number of samples
        * `n_elems` is the number of elements in the input vector

        Raises `ValueError` if `n_samples` is negative or if `n_elems` is less
        than $1$ (every sample has at least one non-zero element).
        """
        if n_samples < 0:
            raise ValueError(f"n_samples must be non-negative, got {n_samples}")
        if n_elems < 1:
            raise ValueError(f"n_elems must be at least 1, got {n_elems}")
        self.n_samples = n_samples
        self.n_elems = n_elems

    def __len__(self):
        """
        Size of the dataset
        """
        return self.n_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate a sample

        * `idx` must satisfy `-n_samples <= idx < n_samples`, otherwise
          `IndexError` is raised

        Returns `(x, y)`:
        * `x` is a vector of length `n_elems` (default floating dtype) of
          $0$'s, $1$'s and $-1$'s
        * `y` is a scalar integer tensor: $1$ if `x` has an odd number of
          $1$'s, $0$ otherwise
        """
        # Samples are random, so `idx` is only used for the bounds check.
        # Raising `IndexError` here is what terminates Python's fallback
        # sequence iteration (`for x, y in dataset`, `list(dataset)`), since
        # `Dataset` defines no `__iter__`; without it iteration never ends.
        if not -self.n_samples <= idx < self.n_samples:
            raise IndexError(
                f"index {idx} out of range for dataset of size {self.n_samples}"
            )

        # Empty vector
        x = torch.zeros((self.n_elems,))
        # Number of non-zero elements - a random number between $1$ and total number of elements
        n_non_zero = torch.randint(1, self.n_elems + 1, (1,)).item()
        # Fill non-zero elements with $1$'s and $-1$'s
        # (`randint(0, 2) * 2 - 1` maps $\{0, 1\}$ to $\{-1, 1\}$)
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
        # Randomly permute the elements
        x = x[torch.randperm(self.n_elems)]

        # The parity: number of $1$'s modulo $2$ ($-1$'s do not count)
        y = (x == 1.).sum() % 2

        #
        return x, y
|