mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-31 18:58:43 +08:00)

	optimizers

labml_nn/optimizers/__init__.py (new empty file)
							
								
								
									
labml_nn/optimizers/ada_belief/__init__.py (new file, 175 lines)
@@ -0,0 +1,175 @@
"""
This is forked from the AdaBelief official implementation
https://github.com/juntang-zhuang/Adabelief-Optimizer
"""
import math

import torch
from torch.optim.optimizer import Optimizer


class AdaBelief(Optimizer):
    r"""Implements the AdaBelief algorithm. Modified from Adam in PyTorch
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-16)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        weight_decouple (boolean, optional): (default: True) If set to True, the
            optimizer uses decoupled weight decay as in AdamW
        fixed_decay (boolean, optional): (default: False) This is used when weight_decouple
            is set to True.
            When fixed_decay == True, the weight decay is performed as
            $W_{new} = W_{old} - W_{old} \times decay$.
            When fixed_decay == False, the weight decay is performed as
            $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this case the
            weight decay ratio decreases with the learning rate (lr).
        rectify (boolean, optional): (default: True) If set to True, perform the rectified
            update similar to RAdam
        degenerated_to_sgd (boolean, optional): (default: True) If set to True, perform an SGD update
            when the variance of the gradient is high
    reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False, rectify=True,
                 degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]

        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad,
                        buffer=[[None, None, None] for _ in range(10)])
        super().__init__(params, defaults)

        self.degenerated_to_sgd = degenerated_to_sgd
        self.weight_decouple = weight_decouple
        self.rectify = rectify
        self.fixed_decay = fixed_decay

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('AdaBelief does not support sparse gradients,'
                                       ' please consider SparseAdam instead')

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_var'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if group['amsgrad']:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_var'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                beta1, beta2 = group['betas']

                # get current state variable
                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Update first and second moment running average
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                grad_residual = grad - exp_avg
                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

                if group['amsgrad']:
                    max_exp_avg_var = state['max_exp_avg_var']
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_var, exp_avg_var, out=max_exp_avg_var)

                    # Use the max. for normalizing running avg. of gradient
                    denom = ((max_exp_avg_var + group['eps']).sqrt_() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                # perform weight decay, check if decoupled weight decay
                if self.weight_decouple:
                    if not self.fixed_decay:
                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
                    else:
                        p.data.mul_(1.0 - group['weight_decay'])
                else:
                    if group['weight_decay'] != 0:
                        grad.add_(p.data, alpha=group['weight_decay'])

                # update
                if not self.rectify:
                    # Default update
                    step_size = group['lr'] / bias_correction1
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
                else:  # Rectified update, forked from RAdam
                    buffered = group['buffer'][int(state['step'] % 10)]
                    if state['step'] == buffered[0]:
                        N_sma, step_size = buffered[1], buffered[2]
                    else:
                        buffered[0] = state['step']
                        beta2_t = beta2 ** state['step']
                        N_sma_max = 2 / (1 - beta2) - 1
                        N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                        buffered[1] = N_sma

                        # more conservative since it's an approximated value
                        if N_sma >= 5:
                            step_size = math.sqrt(
                                (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                        N_sma_max - 2)) / (1 - beta1 ** state['step'])
                        elif self.degenerated_to_sgd:
                            step_size = 1.0 / (1 - beta1 ** state['step'])
                        else:
                            step_size = -1
                        buffered[2] = step_size

                    if N_sma >= 5:
                        denom = exp_avg_var.sqrt().add_(group['eps'])
                        p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    elif step_size > 0:
                        p.data.add_(exp_avg, alpha=-step_size * group['lr'])

        return loss
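The optimizer above follows the standard torch.optim.Optimizer interface, so it can be dropped into an ordinary PyTorch training loop. A minimal sketch, assuming a hypothetical toy model and random data; the hyperparameter values shown are illustrative, not values prescribed by this commit:

import torch
import torch.nn as nn

from labml_nn.optimizers.ada_belief import AdaBelief

model = nn.Linear(10, 1)  # toy model, purely for illustration
# note the much smaller default eps (1e-16) compared to Adam's usual 1e-8
optimizer = AdaBelief(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                      weight_decouple=True, rectify=True)

for _ in range(100):
    x, y = torch.randn(32, 10), torch.randn(32, 1)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()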
							
								
								
									
labml_nn/optimizers/ada_belief/mnist.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.seed import SeedConfigs
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs


class Net(Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def __call__(self, x: torch.Tensor):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class Configs(MNISTConfigs, TrainValidConfigs):
    optimizer: torch.optim.Adam
    model: nn.Module
    set_seed = SeedConfigs()
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    model: nn.Module
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(len(data))

        with self.mode.update(is_log_activations=batch_idx.is_last):
            output = self.model(data)

        loss = self.loss_func(output, target)
        self.accuracy_func(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
                tracker.add('optimizer', (self.optimizer, {'model': self.model}))
            self.optimizer.zero_grad()

        tracker.save()


@option(Configs.model)
def model(c: Configs):
    return Net().to(c.device)


@option(OptimizerConfigs.optimizer, 'AdaBelief')
def ada_belief(c: OptimizerConfigs):
    from labml_nn.optimizers.ada_belief import AdaBelief
    return AdaBelief(c.parameters, lr=c.learning_rate, betas=c.betas, eps=c.eps)


@option(Configs.optimizer)
def _optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf


def main():
    conf = Configs()
    conf.inner_iterations = 10
    experiment.create(name='mnist_ada_belief')
    experiment.configs(conf, {'inner_iterations': 10,
                              'optimizer.optimizer': 'AdaBelief',
                              'optimizer.learning_rate': 1.5e-4})
    conf.set_seed.set()
    experiment.add_pytorch_models(dict(model=conf.model))
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
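The script registers AdaBelief as an OptimizerConfigs option and selects it through the 'optimizer.optimizer' key passed to experiment.configs(); changing that value would select a different registered optimizer option. Assuming labml, labml_helpers and this package are importable, a minimal way to launch the experiment is:

# equivalent to running the module directly (python -m labml_nn.optimizers.ada_belief.mnist)
from labml_nn.optimizers.ada_belief.mnist import main

main()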
							
								
								
									
labml_nn/optimizers/radam/__init__.py (new file, 254 lines)
@@ -0,0 +1,254 @@
"""
Forked from https://github.com/LiyuanLucasLiu/RAdam
"""

import math
import torch
from torch.optim.optimizer import Optimizer


class RAdam(Optimizer):
    # RAdam (Rectified Adam), with the rectification term cached in a small
    # per-parameter-group buffer indexed by `step % 10`.

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        buffer=[[None, None, None] for _ in range(10)])
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Update biased first and second moment estimates
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    # Length of the approximated simple moving average (SMA) and its maximum
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                    N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        step_size = -1
                    buffered[2] = step_size

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)

        return loss


class PlainRAdam(Optimizer):
    # Same as RAdam, but recomputes the rectification term every step
    # instead of caching it in a buffer.

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super(PlainRAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PlainRAdam, self).__setstate__(state)

    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                beta2_t = beta2 ** state['step']
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    step_size = group['lr'] * math.sqrt(
                        (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                    p.data.copy_(p_data_fp32)
                elif self.degenerated_to_sgd:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    step_size = group['lr'] / (1 - beta1 ** state['step'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size)
                    p.data.copy_(p_data_fp32)

        return loss


class AdamW(Optimizer):
    # Adam with decoupled weight decay and an optional linear learning-rate warmup.

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, warmup=warmup)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                denom = exp_avg_sq.sqrt().add_(group['eps'])
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Linear warmup of the learning rate over the first `warmup` steps
                if group['warmup'] > state['step']:
                    scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
                else:
                    scheduled_lr = group['lr']

                step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * scheduled_lr)

                p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
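Both RAdam classes above and the rectify branch of AdaBelief compute the same rectification term from the RAdam paper. Writing $t$ for state['step'], $\rho_\infty$ for N_sma_max and $\rho_t$ for N_sma (the paper's notation, introduced here only for readability), the cached step_size corresponds to

$\rho_\infty = \frac{2}{1 - \beta_2} - 1, \qquad \rho_t = \rho_\infty - \frac{2 t \beta_2^t}{1 - \beta_2^t}$

$r_t = \sqrt{\frac{(\rho_t - 4)(\rho_t - 2)\,\rho_\infty}{(\rho_\infty - 4)(\rho_\infty - 2)\,\rho_t}}, \qquad \text{step\_size} = \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}\, r_t$

The factor $\sqrt{1 - \beta_2^t}$ folds the second-moment bias correction into the step size, since the denominator uses the uncorrected $\sqrt{v_t} + \epsilon$. When $\rho_t < 5$ the variance estimate is treated as unreliable and, with degenerated_to_sgd enabled, the update falls back to $\theta \leftarrow \theta - \frac{\alpha}{1 - \beta_1^t}\, m_t$.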
Varuna Jayasiri