arithmetic test score

This commit is contained in:
Varuna Jayasiri
2022-06-01 14:07:27 +05:30
parent c08af45b03
commit e409e9bf98
2 changed files with 50 additions and 35 deletions

View File

@@ -9,9 +9,8 @@ from typing import List
 import torch
 from torch.utils.data import DataLoader, Dataset
-from labml import monit, logger
+from labml import monit, logger, tracker
 from labml.configs import option
 from labml.logger import Text
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
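The import change swaps in tracker, which the new scoring code at the end of this diff uses to record the metric. labml's tracker.save records a scalar against the current training step; a minimal usage sketch (the value is made up):

    from labml import tracker

    # Record a scalar metric named 'score' for the current training step;
    # labml plots it alongside the other tracked indicators.
    tracker.save('score', 0.75)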
@@ -57,6 +56,12 @@ class ArithmeticDataset(Dataset):
         explanation = self.get_add_explanation(x, y)
         return f"x={x}+{y}; {explanation} x=={x + y}\n"

+    def get_qa(self):
+        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+
+        return f'x={x}+{y};', f'{x + y}'
+
     def get_packed_math_input(self):
         s_enc = []
         while len(s_enc) <= self.seq_len:
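For illustration, get_qa returns the question and the expected answer as separate strings, so the evaluator can compare the model's completion against the true answer (example values below):

    # Arguments are positional, as used elsewhere in this file.
    dataset = ArithmeticDataset(256, 4, 1)
    q, a = dataset.get_qa()
    # e.g. q == 'x=12+345;' and a == '357'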
@@ -81,10 +86,11 @@ class ArithmeticDataset(Dataset):
 class ArithmeticAutoregression(NLPAutoRegressionConfigs):
     max_digits: int = 4
-    train_sequences_per_epoch: int = 2 ** 14
-    valid_sequences_per_epoch: int = 2 ** 4
+    train_sequences_per_epoch: int = 2 ** 12
     train_loader: DataLoader = 'arithmetic_train_loader'
-    valid_loader: DataLoader = 'arithmetic_valid_loader'
+    n_tests: int = 32
+    validator = None
+    inner_iterations = 4
     n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
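Validation is switched off here: validator = None, and the valid_loader default is gone (its loader function is deleted further down), so model quality is now measured by the sampled test score instead of a validation loss. With n_tests = 32, that score moves in steps of 1/32:

    # Each of the n_tests questions contributes equally to the tracked score.
    n_tests = 32
    correct = 29
    score = correct / n_tests  # 0.90625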
@@ -93,32 +99,52 @@ class ArithmeticAutoregression(NLPAutoRegressionConfigs):
         ### Sampling function to generate samples periodically while training
         """
-        # Starting prompt
-        prompt = self.prompt
-        # Collect output for printing
-        log = [(prompt, Text.subtle)]
+        # Skip evaluation on the very first pass, before any training
+        if self.training_loop.idx < 1:
+            return
         # Dataset for decoding
         dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
-        # Sample 25 tokens
-        for i in monit.iterate('Sample', self.seq_len - len(prompt)):
-            # Tokenize the prompt
-            data = torch.tensor(dataset.encode(prompt))[:, None]
+        # Generate test questions and start each sequence with its first character
+        qa = [dataset.get_qa() for _ in range(self.n_tests)]
+        prompt = [p[0] for p in qa]
+        data = torch.tensor([[dataset.stoi[p[0]] for p in prompt]])
+        data = data.to(self.device)
+        # Track which sequences have emitted a newline, i.e. finished answering
+        finished = torch.zeros((len(prompt),)).bool().to(self.device)
+        new_line = dataset.stoi['\n']
+        results = [p[0] for p in prompt]
+        # Sample one token at a time, up to the sequence length
+        for i in monit.iterate('Sample', self.seq_len - 1):
             # Get the model output
             output, *_ = self.model(data)
             # Get the model prediction (greedy)
-            output = output.argmax(dim=-1).squeeze()
-            if dataset.itos[output[-1]] == '\n':
+            output = output[-1].argmax(dim=-1)
+            # Stop once every sequence has produced a newline
+            finished = finished | (output == new_line)
+            if finished.sum() == len(finished):
                 break
-            # Add the prediction to prompt
-            prompt += self.prompt_separator + dataset.itos[output[-1]]
-            # Add the prediction for logging
-            log += [(self.prompt_separator + dataset.itos[output[-1]], Text.value)]
-        # Print the sampled output
-        logger.log(log)
+            # While a question is still being fed, override the prediction
+            # with its next character (teacher-forcing the prompt)
+            for j, p in enumerate(prompt):
+                if len(p) > i + 1:
+                    output[j] = dataset.stoi[p[i + 1]]
+            # Append the chosen tokens and record the decoded characters
+            data = torch.cat([data, output[None, :]], dim=0)
+            for j, c in enumerate(output):
+                results[j] += dataset.itos[c]
+        # Keep only the first line of each result and log one sample
+        results = [r.split('\n')[0] for r in results]
+        logger.log(results[0])
+        # Extract the predicted answers and count exact matches
+        results = [r.split('x==')[-1] for r in results]
+        correct = 0
+        for r, _qa in zip(results, qa):
+            if r == _qa[1]:
+                correct += 1
+        # Track the fraction of correctly answered questions
+        tracker.save('score', correct / len(results))

 @option(ArithmeticAutoregression.train_loader)
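In effect the new sample() is a batched exact-match evaluator: it decodes all n_tests questions greedily in one batch, forces the remaining question characters instead of the model's guesses while each question lasts, and tracks the fraction of exactly correct answers. A self-contained sketch of the same technique with a stand-in model (toy_model, the character table, and the hard-coded questions are illustrative, not from the repository):

    import torch

    # Illustrative vocabulary; the real `itos`/`stoi` come from ArithmeticDataset.
    itos = list('0123456789x=+;\n')
    stoi = {c: i for i, c in enumerate(itos)}

    def toy_model(data: torch.Tensor) -> torch.Tensor:
        # Stand-in for the trained transformer: random logits of
        # shape [seq_len, batch_size, n_tokens].
        return torch.randn(data.shape[0], data.shape[1], len(itos))

    questions = ['x=12+34;', 'x=5+678;']
    answers = ['46', '683']

    # Start every sequence with the first question character.
    data = torch.tensor([[stoi[q[0]] for q in questions]])
    results = [q[0] for q in questions]
    finished = torch.zeros(len(questions)).bool()

    for i in range(32):
        output = toy_model(data)
        # Greedy pick from the logits at the last position.
        output = output[-1].argmax(dim=-1)
        finished = finished | (output == stoi['\n'])
        if finished.all():
            break
        # Teacher-force the rest of the question while it lasts.
        for j, q in enumerate(questions):
            if len(q) > i + 1:
                output[j] = stoi[q[i + 1]]
        data = torch.cat([data, output[None, :]], dim=0)
        for j, c in enumerate(output):
            results[j] += itos[c]

    # Exact-match score: text after 'x==' on the first line vs the true answer.
    preds = [r.split('\n')[0].split('x==')[-1] for r in results]
    score = sum(p == a for p, a in zip(preds, answers)) / len(answers)
    print(score)

The answer is taken as everything after 'x==' because the training strings end with x==<answer>, matching the f"x={x}+{y}; {explanation} x=={x + y}\n" format shown earlier in this diff.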
@@ -129,14 +155,6 @@ def arithmetic_train_loader(c: ArithmeticAutoregression):
                       num_workers=4)


-@option(ArithmeticAutoregression.valid_loader)
-def arithmetic_valid_loader(c: ArithmeticAutoregression):
-    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.valid_sequences_per_epoch),
-                      batch_size=c.batch_size,
-                      collate_fn=transpose_batch,
-                      num_workers=4)
-
-
 def _test():
     dataset = ArithmeticDataset(256, 8, 10)
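With validator = None, the valid_loader option was dead configuration, so arithmetic_valid_loader is deleted rather than left unreachable. The surviving train loader still relies on transpose_batch; its actual implementation lives in labml_nn.experiments.nlp_autoregression, but a plausible sketch (an assumption, named _sketch to make that clear) is a collate function producing sequence-first tensors:

    import torch
    from typing import List, Tuple

    def transpose_batch_sketch(batch: List[Tuple[torch.Tensor, torch.Tensor]]):
        # Stack (input, target) pairs along dim=1 so each tensor is
        # [seq_len, batch_size], the layout the sampling code expects.
        inputs = torch.stack([b[0] for b in batch], dim=1)
        targets = torch.stack([b[1] for b in batch], dim=1)
        return inputs, targets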

View File

@@ -26,7 +26,7 @@ class Configs(RoPEConfigs, ArithmeticAutoregression):  # , ArithmeticAutoregress
 def _rotary_value_pe_mha(c: TransformerConfigs):
     from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
-    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 0.5)
+    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.)


 # Configuration options
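The only change here is the last argument of RotaryValuePEMultiHeadAttention, raised from 0.5 to 1.0. Assuming the final two arguments follow labml_nn's rope_percentage pattern (the fractions of query/key and of value features that receive the rotary rotation), the experiment now rotates all value features instead of half. A minimal sketch of that idea, with hypothetical helper names:

    import torch

    def rope_rotate(x: torch.Tensor, base: float = 10_000.) -> torch.Tensor:
        # x: [seq_len, batch, heads, d] with d even; rotate feature pairs
        # by position-dependent angles, the core of rotary embeddings.
        seq_len, d = x.shape[0], x.shape[-1]
        theta = base ** (-torch.arange(0, d, 2).float() / d)              # [d/2]
        angles = torch.arange(seq_len).float()[:, None] * theta[None, :]  # [seq_len, d/2]
        cos = angles.cos()[:, None, None, :]
        sin = angles.sin()[:, None, None, :]
        x1, x2 = x[..., 0::2], x[..., 1::2]
        rotated = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        return rotated.flatten(-2)

    def rope_value(v: torch.Tensor, fraction: float) -> torch.Tensor:
        # Rotate only the first `fraction` of value features;
        # fraction=1.0 (this commit) rotates everything.
        d_rope = int(v.shape[-1] * fraction)
        d_rope -= d_rope % 2  # keep it even so features pair up
        return torch.cat([rope_rotate(v[..., :d_rope]), v[..., d_rope:]], dim=-1)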
@@ -42,7 +42,7 @@ def main():
     conf = Configs()
     # Override configurations
     experiment.configs(conf, {
-        'max_digits': 9,
+        'max_digits': 6,

         # No fixed positional embeddings
         'transformer.src_embed': 'no_pos',
@@ -63,12 +63,9 @@ def main():
         # Use a context size of $512$
         'seq_len': 512,
         # Train for 64 epochs
-        'epochs': 32,
+        'epochs': 64,
         # Batch size $16$
         'batch_size': 16,
-        # Switch between training and validation for $10$ times
-        # per epoch
-        'inner_iterations': 10,

         # Model size
         'd_model': 128,
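The inner_iterations override is dropped because the class now fixes it at 4, and with validation gone there is little to interleave. Not shown in the hunk, but a labml main() of this shape conventionally finishes by creating the experiment, registering the model, and running the training loop; a sketch under that assumption (the experiment name is made up, while experiment.create, experiment.add_pytorch_models, experiment.start, and conf.run are standard labml calls):

    from labml import experiment

    def main():
        conf = Configs()
        experiment.create(name="rope_arithmetic")
        experiment.configs(conf, {
            'max_digits': 6,
            # ... remaining overrides as above ...
        })
        # Register the model so labml checkpoints it with the run.
        experiment.add_pytorch_models({'model': conf.model})
        with experiment.start():
            conf.run()

    if __name__ == '__main__':
        main()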