sophia exp

This commit is contained in:
Varuna Jayasiri
2023-07-14 16:44:45 +05:30
parent f45ca5ee69
commit 7c02294e7c
3 changed files with 162 additions and 13 deletions

View File

@ -67,6 +67,8 @@ class OptimizerConfigs(BaseConfigs):
# Model embedding size for Noam optimizer
d_model: int
rho: float
def __init__(self):
super().__init__(_primary='optimizer')
@ -137,6 +139,14 @@ def _noam_optimizer(c: OptimizerConfigs):
d_model=c.d_model)
@option(OptimizerConfigs.optimizer, 'Sophia')
def _sophia_optimizer(c: OptimizerConfigs):
from labml_nn.optimizers.sophia import Sophia
return Sophia(c.parameters,
lr=c.learning_rate, betas=c.betas, eps=c.eps,
weight_decay=c.weight_decay_obj, rho=c.rho)
@option(OptimizerConfigs.optimizer, 'AdamWarmupCosineDecay')
def _noam_optimizer(c: OptimizerConfigs):
from labml_nn.optimizers.adam_warmup_cosine_decay import AdamWarmupCosineDecay

View File

@ -29,9 +29,7 @@ class Sophia(GenericAdaptiveOptimizer):
def __init__(self, params,
lr: float = 1e-4, betas: Tuple[float, float] = (0.965, 0.99), eps: float = 1e-16,
rho: float = 0.04,
training_batch_tokens: int = None,
weight_decay: WeightDecay = WeightDecay(),
optimized_update: bool = True,
defaults: Optional[Dict[str, Any]] = None):
"""
### Initialize the optimizer
@ -42,21 +40,15 @@ class Sophia(GenericAdaptiveOptimizer):
* `eps` is $\epsilon$
* `pho` is $\rho$
* `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
* `optimized_update` is a flag whether to optimize the bias correction of the second moment
by doing it after adding $\epsilon$
* `defaults` is a dictionary of default for group values.
This is useful when you want to extend the class `Adam`.
"""
if training_batch_tokens is None:
raise RuntimeError('Please set the number of tokens per training batch.')
defaults = {} if defaults is None else defaults
defaults.update(weight_decay.defaults())
defaults.update(dict(rho=rho, training_batch_tokens=training_batch_tokens))
defaults.update(dict(rho=rho))
super().__init__(params, defaults, lr, betas, eps)
self.weight_decay = weight_decay
self.optimized_update = optimized_update
def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
"""
@ -75,7 +67,7 @@ class Sophia(GenericAdaptiveOptimizer):
# Exponential moving average of Hessian
state['hessian'] = torch.zeros_like(param, memory_format=torch.preserve_format)
def update_hessian(self, batch_size):
def update_hessian(self, n_tokens_training_batch):
for group in self.param_groups:
beta1, beta2 = group['betas']
for p in group['params']:
@ -86,7 +78,7 @@ class Sophia(GenericAdaptiveOptimizer):
if len(state) == 0:
self.init_state(state, group, p)
state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * batch_size)
state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * n_tokens_training_batch)
def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
"""
@ -107,7 +99,7 @@ class Sophia(GenericAdaptiveOptimizer):
rho = group['rho']
# Get $m_{t-1}$ and $v_{t-1}$
m, hessian = state['exp_avg'], state['hessain']
m, hessian = state['exp_avg'], state['hessian']
# In-place calculation of $m_t$
# $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
@ -119,6 +111,6 @@ class Sophia(GenericAdaptiveOptimizer):
# Get learning rate
lr = group['lr']
ratio = (m.abs() / (rho * hessian + group['training_batch_tokens'] * group['eps'])).clamp(None, 1)
ratio = (m.abs() / (rho * hessian + group['eps'])).clamp(None, 1)
param.data.addcmul_(m.sign(), ratio, value=-lr)

View File

@ -0,0 +1,147 @@
import torch
from labml.configs import option
from labml import experiment, tracker
from labml_helpers.train_valid import BatchIndex
from labml_nn.optimizers.sophia import Sophia
from labml_nn.transformers.basic.autoregressive_experiment import Configs as TransformerAutoRegressionConfigs
class Configs(TransformerAutoRegressionConfigs):
"""
## Configurations
This inherits from [`Configs`](autoregressive_experiment.html)
"""
hess_interval: int = 10
optimizer: Sophia
def step(self, batch: any, batch_idx: BatchIndex):
"""
### Training or validation step
"""
# Set training/eval mode
self.model.train(self.mode.is_train)
# Move data to the device
data, target = batch[0].to(self.device), batch[1].to(self.device)
if isinstance(self.optimizer, Sophia) and self.mode.is_train and batch_idx.idx % self.hess_interval == 0:
# Whether to capture model outputs
with self.mode.update(is_log_activations=False):
# Get model outputs.
# It's returning a tuple for states when using RNNs.
# This is not implemented yet. 😜
output, *_ = self.model(data)
samp_dist = torch.distributions.Categorical(logits=output)
y_sample = samp_dist.sample()
# Calculate and log loss
loss = self.loss_func(output, y_sample)
tracker.add("loss.hess.", loss)
# Calculate gradients
loss.backward()
# Clip gradients
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
# Update Hessian estimate
self.optimizer.update_hessian(data.numel())
# Clear the gradients
self.optimizer.zero_grad()
else:
# Move data to the device
data, target = batch[0].to(self.device), batch[1].to(self.device)
# Update global step (number of tokens processed) when in training mode
if self.mode.is_train:
tracker.add_global_step(data.shape[0] * data.shape[1])
# Whether to capture model outputs
with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
# Get model outputs.
# It's returning a tuple for states when using RNNs.
# This is not implemented yet. 😜
output, *_ = self.model(data)
# Calculate and log loss
loss = self.loss_func(output, target)
tracker.add("loss.", loss)
# Calculate and log accuracy
self.accuracy(output, target)
self.accuracy.track()
self.other_metrics(output, target)
# Train the model
if self.mode.is_train:
# Calculate gradients
loss.backward()
# Clip gradients
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
# Take optimizer step
self.optimizer.step()
# Log the model parameters and gradients on last batch of every epoch
if batch_idx.is_last and self.is_log_model_params_grads:
tracker.add('model', self.model)
# Clear the gradients
self.optimizer.zero_grad()
# Save the tracked metrics
tracker.save()
def main():
# Create experiment
experiment.create(name="transformer")
# Create configs
conf = Configs()
# Override configurations
experiment.configs(conf, {
# Use character level tokenizer
'tokenizer': 'character',
# Prompt separator is blank
'prompt_separator': '',
# Starting prompt for sampling
'prompt': 'It is ',
# Use Tiny Shakespeare dataset
'text': 'tiny_shakespeare',
# Use a context size of $256$
'seq_len': 512,
# Train for 32 epochs
'epochs': 32,
# Batch size $32$
'batch_size': 16,
# Switch between training and validation for $10$ times
# per epoch
'inner_iterations': 10,
# Model size
'd_model': 256,
'transformer.n_heads': 16,
'transformer.ffn.d_ff': 1024,
# Use [Noam optimizer](../../optimizers/noam.html)
'optimizer.optimizer': 'Sophia',
'optimizer.learning_rate': 3e-4,
'optimizer.rho': 0.03,
})
# Set models for saving and loading
experiment.add_pytorch_models({'model': conf.model})
# Start the experiment
with experiment.start():
# Run training
conf.run()
#
if __name__ == '__main__':
main()