mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 16:50:39 +08:00
experiment
This commit is contained in:
@ -41,7 +41,7 @@ class ArithmeticDataset(Dataset):
|
|||||||
rx, ry = x % 10, y % 10
|
rx, ry = x % 10, y % 10
|
||||||
total = rx + ry + carry
|
total = rx + ry + carry
|
||||||
explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
|
explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
|
||||||
x, y, c = x // 10, y // 10, total // 10
|
x, y, carry = x // 10, y // 10, total // 10
|
||||||
e += 1
|
e += 1
|
||||||
|
|
||||||
return ' '.join(explanation)
|
return ' '.join(explanation)
|
||||||
@ -51,11 +51,13 @@ class ArithmeticDataset(Dataset):
|
|||||||
x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
|
x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
|
||||||
y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
|
y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
|
||||||
|
|
||||||
|
if random.randrange(0, 5) < 1:
|
||||||
|
return f"x={x}+{y}; x=={x + y}\n"
|
||||||
|
else:
|
||||||
explanation = self.get_add_explanation(x, y)
|
explanation = self.get_add_explanation(x, y)
|
||||||
return f"x={x}+{y}; {explanation} x=={x + y}\n"
|
return f"x={x}+{y}; {explanation} x=={x + y}\n"
|
||||||
|
|
||||||
def get_packed_math_input(self):
|
def get_packed_math_input(self):
|
||||||
s = ""
|
|
||||||
s_enc = []
|
s_enc = []
|
||||||
while len(s_enc) <= self.seq_len:
|
while len(s_enc) <= self.seq_len:
|
||||||
s_part = self.make_add_problem()
|
s_part = self.make_add_problem()
|
||||||
@ -79,8 +81,8 @@ class ArithmeticDataset(Dataset):
|
|||||||
|
|
||||||
class ArithmeticAutoregression(NLPAutoRegressionConfigs):
|
class ArithmeticAutoregression(NLPAutoRegressionConfigs):
|
||||||
max_digits: int = 4
|
max_digits: int = 4
|
||||||
train_sequences_per_epoch: int = 1024
|
train_sequences_per_epoch: int = 2 ** 14
|
||||||
valid_sequences_per_epoch: int = 128
|
valid_sequences_per_epoch: int = 2 ** 4
|
||||||
train_loader: DataLoader = 'arithmetic_train_loader'
|
train_loader: DataLoader = 'arithmetic_train_loader'
|
||||||
valid_loader: DataLoader = 'arithmetic_valid_loader'
|
valid_loader: DataLoader = 'arithmetic_valid_loader'
|
||||||
|
|
||||||
@ -106,6 +108,10 @@ class ArithmeticAutoregression(NLPAutoRegressionConfigs):
|
|||||||
output, *_ = self.model(data)
|
output, *_ = self.model(data)
|
||||||
# Get the model prediction (greedy)
|
# Get the model prediction (greedy)
|
||||||
output = output.argmax(dim=-1).squeeze()
|
output = output.argmax(dim=-1).squeeze()
|
||||||
|
|
||||||
|
if dataset.itos[output[-1]] == '\n':
|
||||||
|
break
|
||||||
|
|
||||||
# Add the prediction to prompt
|
# Add the prediction to prompt
|
||||||
prompt += self.prompt_separator + dataset.itos[output[-1]]
|
prompt += self.prompt_separator + dataset.itos[output[-1]]
|
||||||
# Add the prediction for logging
|
# Add the prediction for logging
|
||||||
@ -119,14 +125,16 @@ class ArithmeticAutoregression(NLPAutoRegressionConfigs):
|
|||||||
def arithmetic_train_loader(c: ArithmeticAutoregression):
|
def arithmetic_train_loader(c: ArithmeticAutoregression):
|
||||||
return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
|
return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
|
||||||
batch_size=c.batch_size,
|
batch_size=c.batch_size,
|
||||||
collate_fn=transpose_batch)
|
collate_fn=transpose_batch,
|
||||||
|
num_workers=4)
|
||||||
|
|
||||||
|
|
||||||
@option(ArithmeticAutoregression.valid_loader)
|
@option(ArithmeticAutoregression.valid_loader)
|
||||||
def arithmetic_valid_loader(c: ArithmeticAutoregression):
|
def arithmetic_valid_loader(c: ArithmeticAutoregression):
|
||||||
return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.valid_sequences_per_epoch),
|
return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.valid_sequences_per_epoch),
|
||||||
batch_size=c.batch_size,
|
batch_size=c.batch_size,
|
||||||
collate_fn=transpose_batch)
|
collate_fn=transpose_batch,
|
||||||
|
num_workers=4)
|
||||||
|
|
||||||
|
|
||||||
def _test():
|
def _test():
|
||||||
|
@ -37,18 +37,20 @@ calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Create experiment
|
# Create experiment
|
||||||
experiment.create(name="rope_arithmetic", comment="rotary_value 1.0, 0.5", writers={'screen', 'labml'})
|
experiment.create(name="rope_arithmetic", comment="rotary_value 1.0", writers={'screen', 'labml'})
|
||||||
# Create configs
|
# Create configs
|
||||||
conf = Configs()
|
conf = Configs()
|
||||||
# Override configurations
|
# Override configurations
|
||||||
experiment.configs(conf, {
|
experiment.configs(conf, {
|
||||||
|
'max_digits': 9,
|
||||||
|
|
||||||
# No fixed positional embeddings
|
# No fixed positional embeddings
|
||||||
'transformer.src_embed': 'no_pos',
|
'transformer.src_embed': 'no_pos',
|
||||||
'transformer.tgt_embed': 'no_pos',
|
'transformer.tgt_embed': 'no_pos',
|
||||||
|
|
||||||
# Encoder with RoPE
|
# Encoder with RoPE
|
||||||
# 'transformer.encoder_attn': 'rotary_value',
|
'transformer.encoder_attn': 'rotary_value',
|
||||||
'transformer.encoder_attn': 'rotary',
|
# 'transformer.encoder_attn': 'rotary',
|
||||||
|
|
||||||
#
|
#
|
||||||
'model': 'rotary_pe_transformer',
|
'model': 'rotary_pe_transformer',
|
||||||
@ -56,29 +58,27 @@ def main():
|
|||||||
# Prompt separator is blank
|
# Prompt separator is blank
|
||||||
'prompt_separator': '',
|
'prompt_separator': '',
|
||||||
# Starting prompt for sampling
|
# Starting prompt for sampling
|
||||||
'prompt': '?x=2345+998;',
|
'prompt': '?x=123456789+1091919;',
|
||||||
|
|
||||||
# Use a context size of $256$
|
# Use a context size of $256$
|
||||||
'seq_len': 128,
|
'seq_len': 512,
|
||||||
# Train for 32 epochs
|
# Train for 32 epochs
|
||||||
'epochs': 32,
|
'epochs': 32,
|
||||||
# Batch size $4$
|
# Batch size $4$
|
||||||
'batch_size': 4,
|
'batch_size': 16,
|
||||||
# Switch between training and validation for $10$ times
|
# Switch between training and validation for $10$ times
|
||||||
# per epoch
|
# per epoch
|
||||||
'inner_iterations': 10,
|
'inner_iterations': 10,
|
||||||
|
|
||||||
# Model size
|
# Model size
|
||||||
'd_model': 256,
|
'd_model': 128,
|
||||||
'transformer.ffn.d_ff': 1024,
|
'transformer.ffn.d_ff': 512,
|
||||||
'transformer.n_heads': 8,
|
'transformer.n_heads': 4,
|
||||||
'transformer.dropout': 0.0,
|
'transformer.dropout': 0.0,
|
||||||
|
|
||||||
# Use [Noam optimizer](../../optimizers/noam.html)
|
# Use [Noam optimizer](../../optimizers/noam.html)
|
||||||
'optimizer.optimizer': 'Noam',
|
'optimizer.optimizer': 'Noam',
|
||||||
'optimizer.learning_rate': 1.,
|
'optimizer.learning_rate': 1.,
|
||||||
|
|
||||||
'dataloader_shuffle_with_replacement': True
|
|
||||||
})
|
})
|
||||||
|
|
||||||
# Set models for saving and loading
|
# Set models for saving and loading
|
||||||
|
Reference in New Issue
Block a user