experiment

This commit is contained in:
Varuna Jayasiri
2022-05-31 22:39:58 +05:30
parent 13686c1d28
commit c08af45b03
2 changed files with 27 additions and 19 deletions

View File

@ -41,7 +41,7 @@ class ArithmeticDataset(Dataset):
rx, ry = x % 10, y % 10
total = rx + ry + carry
explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
x, y, c = x // 10, y // 10, total // 10
x, y, carry = x // 10, y // 10, total // 10
e += 1
return ' '.join(explanation)
@ -51,11 +51,13 @@ class ArithmeticDataset(Dataset):
x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
explanation = self.get_add_explanation(x, y)
return f"x={x}+{y}; {explanation} x=={x + y}\n"
if random.randrange(0, 5) < 1:
return f"x={x}+{y}; x=={x + y}\n"
else:
explanation = self.get_add_explanation(x, y)
return f"x={x}+{y}; {explanation} x=={x + y}\n"
def get_packed_math_input(self):
s = ""
s_enc = []
while len(s_enc) <= self.seq_len:
s_part = self.make_add_problem()
@ -79,8 +81,8 @@ class ArithmeticDataset(Dataset):
class ArithmeticAutoregression(NLPAutoRegressionConfigs):
max_digits: int = 4
train_sequences_per_epoch: int = 1024
valid_sequences_per_epoch: int = 128
train_sequences_per_epoch: int = 2 ** 14
valid_sequences_per_epoch: int = 2 ** 4
train_loader: DataLoader = 'arithmetic_train_loader'
valid_loader: DataLoader = 'arithmetic_valid_loader'
@ -106,6 +108,10 @@ class ArithmeticAutoregression(NLPAutoRegressionConfigs):
output, *_ = self.model(data)
# Get the model prediction (greedy)
output = output.argmax(dim=-1).squeeze()
if dataset.itos[output[-1]] == '\n':
break
# Add the prediction to prompt
prompt += self.prompt_separator + dataset.itos[output[-1]]
# Add the prediction for logging
@ -119,14 +125,16 @@ class ArithmeticAutoregression(NLPAutoRegressionConfigs):
def arithmetic_train_loader(c: ArithmeticAutoregression):
return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
batch_size=c.batch_size,
collate_fn=transpose_batch)
collate_fn=transpose_batch,
num_workers=4)
@option(ArithmeticAutoregression.valid_loader)
def arithmetic_valid_loader(c: ArithmeticAutoregression):
return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.valid_sequences_per_epoch),
batch_size=c.batch_size,
collate_fn=transpose_batch)
collate_fn=transpose_batch,
num_workers=4)
def _test():

View File

@ -37,18 +37,20 @@ calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_
def main():
# Create experiment
experiment.create(name="rope_arithmetic", comment="rotary_value 1.0, 0.5", writers={'screen', 'labml'})
experiment.create(name="rope_arithmetic", comment="rotary_value 1.0", writers={'screen', 'labml'})
# Create configs
conf = Configs()
# Override configurations
experiment.configs(conf, {
'max_digits': 9,
# No fixed positional embeddings
'transformer.src_embed': 'no_pos',
'transformer.tgt_embed': 'no_pos',
# Encoder with RoPE
# 'transformer.encoder_attn': 'rotary_value',
'transformer.encoder_attn': 'rotary',
'transformer.encoder_attn': 'rotary_value',
# 'transformer.encoder_attn': 'rotary',
#
'model': 'rotary_pe_transformer',
@ -56,29 +58,27 @@ def main():
# Prompt separator is blank
'prompt_separator': '',
# Starting prompt for sampling
'prompt': '?x=2345+998;',
'prompt': '?x=123456789+1091919;',
# Use a context size of $256$
'seq_len': 128,
'seq_len': 512,
# Train for 32 epochs
'epochs': 32,
# Batch size $4$
'batch_size': 4,
'batch_size': 16,
# Switch between training and validation for $10$ times
# per epoch
'inner_iterations': 10,
# Model size
'd_model': 256,
'transformer.ffn.d_ff': 1024,
'transformer.n_heads': 8,
'd_model': 128,
'transformer.ffn.d_ff': 512,
'transformer.n_heads': 4,
'transformer.dropout': 0.0,
# Use [Noam optimizer](../../optimizers/noam.html)
'optimizer.optimizer': 'Noam',
'optimizer.learning_rate': 1.,
'dataloader_shuffle_with_replacement': True
})
# Set models for saving and loading