From e409e9bf98d214347e57ccf67e7e2c8089625b97 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Wed, 1 Jun 2022 14:07:27 +0530
Subject: [PATCH] arithmetic test score

---
 labml_nn/experiments/arithmetic_dataset.py | 97 ++++++++++++++-------
 .../rope/value_pe/arithmetic_experiment.py | 15 ++++++---------
 2 files changed, 72 insertions(+), 40 deletions(-)

diff --git a/labml_nn/experiments/arithmetic_dataset.py b/labml_nn/experiments/arithmetic_dataset.py
index e75335f2..d7f93665 100644
--- a/labml_nn/experiments/arithmetic_dataset.py
+++ b/labml_nn/experiments/arithmetic_dataset.py
@@ -9,9 +9,8 @@ from typing import List
 import torch
 from torch.utils.data import DataLoader, Dataset
 
-from labml import monit, logger
+from labml import monit, logger, tracker
 from labml.configs import option
-from labml.logger import Text
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
 
 
@@ -57,6 +56,13 @@ class ArithmeticDataset(Dataset):
         explanation = self.get_add_explanation(x, y)
         return f"x={x}+{y}; {explanation} x=={x + y}\n"
 
+    # Generates a question such as `x=123+456;` and its answer (`579`)
+    def get_qa(self):
+        x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+        y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
+
+        return f'x={x}+{y};', f'{x + y}'
+
     def get_packed_math_input(self):
         s_enc = []
         while len(s_enc) <= self.seq_len:
@@ -81,10 +87,13 @@ class ArithmeticDataset(Dataset):
 
 
 class ArithmeticAutoregression(NLPAutoRegressionConfigs):
     max_digits: int = 4
-    train_sequences_per_epoch: int = 2 ** 14
-    valid_sequences_per_epoch: int = 2 ** 4
+    train_sequences_per_epoch: int = 2 ** 12
     train_loader: DataLoader = 'arithmetic_train_loader'
-    valid_loader: DataLoader = 'arithmetic_valid_loader'
+    # Number of problems to test on when sampling
+    n_tests: int = 32
+    # Disable the validator; the test score from sampling measures accuracy instead
+    validator = None
+    inner_iterations = 4
 
     n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
 
@@ -93,32 +102,66 @@ class ArithmeticAutoregression(NLPAutoRegressionConfigs):
         ### Sampling function to generate samples periodically while training
         """
 
-        # Starting prompt
-        prompt = self.prompt
-        # Collect output for printing
-        log = [(prompt, Text.subtle)]
-        # Dataset for decoding
+        # Skip sampling until some training has happened
+        if self.training_loop.idx < 1:
+            return
+
+        # Dataset for decoding
         dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
-        # Sample 25 tokens
-        for i in monit.iterate('Sample', self.seq_len - len(prompt)):
-            # Tokenize the prompt
-            data = torch.tensor(dataset.encode(prompt))[:, None]
-            data = data.to(self.device)
+        # Generate test problems and their answers
+        qa = [dataset.get_qa() for _ in range(self.n_tests)]
+        prompt = [p[0] for p in qa]
+
+        # Make a batch out of the first character of each problem
+        data = torch.tensor([[dataset.stoi[p[0]] for p in prompt]])
+        data = data.to(self.device)
+
+        # Track which samples have finished (i.e. emitted a newline)
+        finished = torch.zeros((len(prompt),)).bool().to(self.device)
+        new_line = dataset.stoi['\n']
+
+        # Collected outputs, starting with the first character of each prompt
+        results = [p[0] for p in prompt]
+
+        # Sample one token at a time, for up to `seq_len - 1` steps
+        for i in monit.iterate('Sample', self.seq_len - 1):
             # Get the model output
             output, *_ = self.model(data)
             # Get the model prediction (greedy)
-            output = output.argmax(dim=-1).squeeze()
+            output = output[-1].argmax(dim=-1)
 
-            if dataset.itos[output[-1]] == '\n':
+            # Stop when every sample has emitted a newline
+            finished = finished | (output == new_line)
+            if finished.sum() == len(finished):
                 break
 
-            # Add the prediction to prompt
-            prompt += self.prompt_separator + dataset.itos[output[-1]]
-            # Add the prediction for logging
-            log += [(self.prompt_separator + dataset.itos[output[-1]], Text.value)]
+            # While still inside a prompt, feed the prompt character instead of the prediction
+            for j, p in enumerate(prompt):
+                if len(p) > i + 1:
+                    output[j] = dataset.stoi[p[i + 1]]
 
-        # Print the sampled output
-        logger.log(log)
+            # Append the new tokens to the input for the next step
+            data = torch.cat([data, output[None, :]], dim=0)
+
+            # Accumulate the sampled characters
+            for j, c in enumerate(output):
+                results[j] += dataset.itos[c]
+
+        # Keep only the text up to the first newline
+        results = [r.split('\n')[0] for r in results]
+        # Print a sample
+        logger.log(results[0])
+        # Extract the answers, which follow `x==`
+        results = [r.split('x==')[-1] for r in results]
+
+        # Count the correct answers
+        correct = 0
+        for r, _qa in zip(results, qa):
+            if r == _qa[1]:
+                correct += 1
+
+        # Track the fraction of correct answers as the test score
+        tracker.save('score', correct / len(results))
 
 
 @option(ArithmeticAutoregression.train_loader)
@@ -129,14 +172,6 @@ def arithmetic_train_loader(c: ArithmeticAutoregression):
                       num_workers=4)
 
 
-@option(ArithmeticAutoregression.valid_loader)
-def arithmetic_valid_loader(c: ArithmeticAutoregression):
-    return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.valid_sequences_per_epoch),
-                      batch_size=c.batch_size,
-                      collate_fn=transpose_batch,
-                      num_workers=4)
-
-
 def _test():
     dataset = ArithmeticDataset(256, 8, 10)
 
diff --git a/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py b/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py
index ef4d3c0e..2f65d211 100644
--- a/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py
+++ b/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py
@@ -26,7 +26,7 @@ class Configs(RoPEConfigs, ArithmeticAutoregression):  # , ArithmeticAutoregress
 
 def _rotary_value_pe_mha(c: TransformerConfigs):
     from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
-    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 0.5)
+    return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.)
 
 
 # Configuration options
@@ -42,7 +42,7 @@ def main():
     conf = Configs()
     # Override configurations
     experiment.configs(conf, {
-        'max_digits': 9,
+        'max_digits': 6,
 
         # No fixed positional embeddings
         'transformer.src_embed': 'no_pos',
@@ -63,12 +63,9 @@ def main():
-        # Use a context size of $256$
+        # Use a context size of $512$
         'seq_len': 512,
-        # Train for 32 epochs
-        'epochs': 32,
-        # Batch size $4$
+        # Train for 64 epochs
+        'epochs': 64,
+        # Batch size $16$
         'batch_size': 16,
-        # Switch between training and validation for $10$ times
-        # per epoch
-        'inner_iterations': 10,
 
         # Model size
         'd_model': 128,
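
Note on the new test in `sample()`: all `n_tests` problems are decoded as one batch, each sequence starting from the first character of its prompt. While a sequence is still inside its prompt, the prompt character is force-fed in place of the model's prediction, which is what lets prompts of different lengths share a single batch. Decoding stops once every sequence has emitted a newline, and the fraction of correct answers is tracked as `score`. A minimal standalone sketch of the same loop follows; `evaluate_addition`, `model`, `stoi`, `itos`, and `qa` are illustrative stand-ins rather than the repository's API, assuming a callable that maps a `[seq_len, batch]` token tensor to `[seq_len, batch, n_tokens]` logits, character vocabularies, and a list of `(question, answer)` string pairs.

import torch

def evaluate_addition(model, stoi, itos, qa, seq_len, device='cpu'):
    """Returns the fraction of problems in `qa` answered correctly (a sketch, not the repo's API)."""
    prompts = [question for question, _ in qa]
    # Start every sequence with the first character of its prompt
    data = torch.tensor([[stoi[p[0]] for p in prompts]], device=device)
    finished = torch.zeros(len(prompts), dtype=torch.bool, device=device)
    results = [p[0] for p in prompts]

    for i in range(seq_len - 1):
        # Greedy prediction from the logits at the last position
        logits = model(data)  # assumed shape: [seq_len, batch, n_tokens]
        output = logits[-1].argmax(dim=-1)
        # Stop once every sequence has produced a newline
        finished |= output == stoi['\n']
        if finished.all():
            break
        # While still inside a prompt, force-feed the prompt character
        for j, p in enumerate(prompts):
            if len(p) > i + 1:
                output[j] = stoi[p[i + 1]]
        # Append the chosen tokens and record the decoded characters
        data = torch.cat([data, output[None, :]], dim=0)
        for j, t in enumerate(output):
            results[j] += itos[t.item()]

    # The answer is whatever follows `x==`, up to the first newline
    answers = [r.split('\n')[0].split('x==')[-1] for r in results]
    return sum(a == answer for (_, answer), a in zip(qa, answers)) / len(qa)

Sequences that finish early keep receiving greedily sampled tokens until every sequence is done, but anything after a sequence's first newline is discarded by the final `split('\n')[0]`, so no masking is needed.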