experiment

Varuna Jayasiri
2022-05-31 22:39:58 +05:30
parent 13686c1d28
commit c08af45b03
2 changed files with 27 additions and 19 deletions

View File

@@ -41,7 +41,7 @@ class ArithmeticDataset(Dataset):
             rx, ry = x % 10, y % 10
             total = rx + ry + carry
             explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
-            x, y, c = x // 10, y // 10, total // 10
+            x, y, carry = x // 10, y // 10, total // 10
             e += 1
         return ' '.join(explanation)
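
The one-word fix above matters: the old line assigned the new carry to a dead variable c, so later columns read a stale carry and a final carry digit was silently dropped. A minimal standalone sketch of the fixed routine (the free-function wrapper is ours; the body matches the diff):

    def get_add_explanation(x: int, y: int) -> str:
        # Walk the addition column by column, tracking the carry
        carry, e = 0, 0
        explanation = []
        while x > 0 or y > 0 or carry > 0:
            rx, ry = x % 10, y % 10
            total = rx + ry + carry
            explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
            # The fix: update `carry` (not a dead `c`) so later columns
            # and the loop condition see the propagated carry
            x, y, carry = x // 10, y // 10, total // 10
            e += 1
        return ' '.join(explanation)

    print(get_add_explanation(56, 87))
    # 6e0+7e0+0e0==13e0 5e1+8e1+1e1==14e1 0e2+0e2+1e2==1e2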
@@ -51,11 +51,13 @@ class ArithmeticDataset(Dataset):
         x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
         y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
-        explanation = self.get_add_explanation(x, y)
-        return f"x={x}+{y}; {explanation} x=={x + y}\n"
+        if random.randrange(0, 5) < 1:
+            return f"x={x}+{y}; x=={x + y}\n"
+        else:
+            explanation = self.get_add_explanation(x, y)
+            return f"x={x}+{y}; {explanation} x=={x + y}\n"

     def get_packed_math_input(self):
-        s = ""
         s_enc = []
         while len(s_enc) <= self.seq_len:
             s_part = self.make_add_problem()
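
With the new branch, roughly one problem in five is emitted without the worked explanation, so the model must sometimes jump straight to the answer. A hedged sketch of the generator as a free function, with a hypothetical stand-in for make_int and get_add_explanation reused from the sketch above:

    import random

    def make_int(n_digits: int) -> int:
        # Hypothetical stand-in for ArithmeticDataset.make_int:
        # a random integer with exactly n_digits digits
        low = 10 ** (n_digits - 1)
        return random.randrange(low, low * 10)

    def make_add_problem(max_digits: int = 4) -> str:
        x = make_int(random.randrange(1, max_digits + 1))
        y = make_int(random.randrange(1, max_digits + 1))
        # randrange(0, 5) < 1 holds with probability 1/5:
        # one in five samples carries no step-by-step explanation
        if random.randrange(0, 5) < 1:
            return f"x={x}+{y}; x=={x + y}\n"
        return f"x={x}+{y}; {get_add_explanation(x, y)} x=={x + y}\n"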
@@ -79,8 +81,8 @@ class ArithmeticDataset(Dataset):
 class ArithmeticAutoregression(NLPAutoRegressionConfigs):
     max_digits: int = 4
-    train_sequences_per_epoch: int = 1024
-    valid_sequences_per_epoch: int = 128
+    train_sequences_per_epoch: int = 2 ** 14
+    valid_sequences_per_epoch: int = 2 ** 4
     train_loader: DataLoader = 'arithmetic_train_loader'
     valid_loader: DataLoader = 'arithmetic_valid_loader'
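
Spelled out, the new epoch sizes are:

    assert 2 ** 14 == 16384  # training sequences per epoch, up from 1024
    assert 2 ** 4 == 16      # validation sequences per epoch, down from 128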
@@ -106,6 +108,10 @@ class ArithmeticAutoregression(NLPAutoRegressionConfigs):
             output, *_ = self.model(data)
             # Get the model prediction (greedy)
             output = output.argmax(dim=-1).squeeze()
+            # Stop sampling once the model predicts the end-of-problem marker
+            if dataset.itos[output[-1]] == '\n':
+                break
             # Add the prediction to prompt
             prompt += self.prompt_separator + dataset.itos[output[-1]]
             # Add the prediction for logging
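
A hedged sketch of the sampling loop around this hunk, with hypothetical stand-ins for the experiment's model and the dataset's stoi/itos vocabularies; the early break is the behavior this commit adds:

    import torch

    def sample(model, stoi: dict, itos: list, prompt: str, max_new_tokens: int = 64) -> str:
        for _ in range(max_new_tokens):
            # Encode the prompt as a (seq_len, batch=1) tensor of token ids
            data = torch.tensor([stoi[c] for c in prompt])[:, None]
            output, *_ = model(data)
            # Greedy prediction: the most likely token at each position
            output = output.argmax(dim=-1).squeeze()
            # New in this commit: stop as soon as the model emits '\n',
            # the end-of-problem marker, instead of running to a fixed length
            if itos[output[-1]] == '\n':
                break
            prompt += itos[output[-1]]
        return prompt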
@@ -119,14 +125,16 @@ class ArithmeticAutoregression(NLPAutoRegressionConfigs):
 def arithmetic_train_loader(c: ArithmeticAutoregression):
     return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
                       batch_size=c.batch_size,
-                      collate_fn=transpose_batch)
+                      collate_fn=transpose_batch,
+                      num_workers=4)


 @option(ArithmeticAutoregression.valid_loader)
 def arithmetic_valid_loader(c: ArithmeticAutoregression):
     return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.valid_sequences_per_epoch),
                       batch_size=c.batch_size,
-                      collate_fn=transpose_batch)
+                      collate_fn=transpose_batch,
+                      num_workers=4)


 def _test():
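
num_workers=4 moves batch construction into four background processes, so the CPU-bound random problem generation overlaps with training steps. A runnable sketch with toy stand-ins for the dataset and the repo's transpose_batch collate function:

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        # Hypothetical stand-in for ArithmeticDataset: random token sequences
        def __init__(self, seq_len: int, size: int):
            self.seq_len, self.size = seq_len, size

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            seq = torch.randint(0, 10, (self.seq_len + 1,))
            return seq[:-1], seq[1:]  # input and shifted target

    def transpose_batch(batch):
        # Stand-in for the repo's collate function: stack to (seq_len, batch),
        # the sequence-first layout the transformer expects
        xs = torch.stack([x for x, _ in batch], dim=1)
        ys = torch.stack([y for _, y in batch], dim=1)
        return xs, ys

    if __name__ == '__main__':
        loader = DataLoader(ToyDataset(seq_len=512, size=2 ** 14),
                            batch_size=16,
                            collate_fn=transpose_batch,
                            num_workers=4)  # four background worker processes
        x, y = next(iter(loader))
        print(x.shape)  # torch.Size([512, 16])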

View File

@@ -37,18 +37,20 @@ calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_
 def main():
     # Create experiment
-    experiment.create(name="rope_arithmetic", comment="rotary_value 1.0, 0.5", writers={'screen', 'labml'})
+    experiment.create(name="rope_arithmetic", comment="rotary_value 1.0", writers={'screen', 'labml'})
     # Create configs
     conf = Configs()
     # Override configurations
     experiment.configs(conf, {
+        'max_digits': 9,
         # No fixed positional embeddings
         'transformer.src_embed': 'no_pos',
         'transformer.tgt_embed': 'no_pos',
         # Encoder with RoPE
-        # 'transformer.encoder_attn': 'rotary_value',
-        'transformer.encoder_attn': 'rotary',
+        'transformer.encoder_attn': 'rotary_value',
+        # 'transformer.encoder_attn': 'rotary',
         #
         'model': 'rotary_pe_transformer',
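
The swap from 'rotary' to 'rotary_value' applies the rotary position encoding to the value vectors as well, not only to queries and keys. A hedged sketch of the core rotation both variants share, using the interleaved pairing convention (the repo's implementation may pair dimensions differently):

    import torch

    def rope_rotate(x: torch.Tensor, base: float = 10_000.0) -> torch.Tensor:
        # Rotate feature pairs of x (shape [seq_len, d], d even)
        # by position-dependent angles
        seq_len, d = x.shape
        theta = base ** (-torch.arange(0, d, 2).float() / d)              # [d/2]
        angle = torch.arange(seq_len).float()[:, None] * theta[None, :]   # [seq, d/2]
        cos, sin = angle.cos(), angle.sin()
        x1, x2 = x[:, 0::2], x[:, 1::2]
        out = torch.empty_like(x)
        out[:, 0::2] = x1 * cos - x2 * sin
        out[:, 1::2] = x1 * sin + x2 * cos
        return out

    q = torch.randn(10, 64)
    print(rope_rotate(q).shape)  # torch.Size([10, 64])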
@@ -56,29 +58,27 @@ def main():
         # Prompt separator is blank
         'prompt_separator': '',
         # Starting prompt for sampling
-        'prompt': '?x=2345+998;',
-        # Use a context size of $256$
-        'seq_len': 128,
+        'prompt': '?x=123456789+1091919;',
+        # Use a context size of $512$
+        'seq_len': 512,
         # Train for 32 epochs
         'epochs': 32,
-        # Batch size $4$
-        'batch_size': 4,
+        # Batch size $16$
+        'batch_size': 16,
         # Switch between training and validation for $10$ times
         # per epoch
         'inner_iterations': 10,
         # Model size
-        'd_model': 256,
-        'transformer.ffn.d_ff': 1024,
-        'transformer.n_heads': 8,
+        'd_model': 128,
+        'transformer.ffn.d_ff': 512,
+        'transformer.n_heads': 4,
         'transformer.dropout': 0.0,
         # Use [Noam optimizer](../../optimizers/noam.html)
         'optimizer.optimizer': 'Noam',
         'optimizer.learning_rate': 1.,
-        'dataloader_shuffle_with_replacement': True
     })
     # Set models for saving and loading
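
The smaller model roughly quarters the per-layer weight count; a back-of-the-envelope check (our estimate, ignoring biases, layer norms and embeddings):

    def layer_params(d_model: int, d_ff: int) -> int:
        # Q, K, V and output projections, plus the two feed-forward matrices
        return 4 * d_model * d_model + 2 * d_model * d_ff

    print(layer_params(256, 1024))  # 786432 (old configuration)
    print(layer_params(128, 512))   # 196608 (new configuration, 4x smaller)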