mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-31 10:48:49 +08:00)
56 lines | 1.8 KiB | Python
from typing import List, Dict

import torch
from torch import nn

from labml_nn.neox.model import TransformerLayer, NeoXModule


class FineTuner:
    """
    Base class for fine-tuning a pipeline of NeoX layers by training only a
    small subset of their parameters.
    """

    def __init__(self, layers: List[NeoXModule]):
        self.layers = layers

    def get_trainable_params(self) -> Dict[str, nn.Parameter]:
        """
        Collect the trainable parameters of all layers, keyed by names prefixed
        with the layer index (`layer_00.`, `layer_01.`, ...).
        """
        params = {}
        for i, layer in enumerate(self.layers):
            params.update(self.get_layer_trainable_params(layer, prefix=f'layer_{i:02d}'))

        return params

    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        """
        Return the trainable parameters of a single layer; subclasses decide which ones.
        """
        raise NotImplementedError

    def set_trainable_params(self):
        # Set `requires_grad` to `False` for every parameter of every layer.
        for layer in self.layers:
            layer.requires_grad_(False)
        # Then enable gradients only for the selected trainable parameters.
        for p in self.get_trainable_params().values():
            p.requires_grad_(True)

    def state_dict(self):
        """
        Return only the trainable parameters, moved to the CPU, so checkpoints stay small.
        """
        return {n: p.data.cpu() for n, p in self.get_trainable_params().items()}

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        """
        Copy saved values back into the trainable parameters, in place.
        """
        params = self.get_trainable_params()
        for n, p in params.items():
            p.data[:] = state_dict[n].to(p.data.device)

        # Make sure the checkpoint does not contain parameters we do not train.
        for n in state_dict.keys():
            assert n in params, n


class FineTuneBiases(FineTuner):
    """
    Fine-tune only the biases of the transformer layers.
    """

    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        params = {}

        if isinstance(layer, TransformerLayer):
            # The MLP output bias is not trained because it gets added together with
            # the attention output, so `attention.output.bias` already covers it.
            params[f'{prefix}.attention.output.bias'] = layer.attention.output.bias
            params[f'{prefix}.attention.qkv_lin.bias'] = layer.attention.qkv_lin.bias
            params[f'{prefix}.ffn.dense_h_h4.bias'] = layer.ffn.dense_h_h4.bias
        else:
            # Other modules contribute no parameters for fine-tuning and stay frozen.
            pass

        return params
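
# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original file). It assumes
# `layers` is a list of `NeoXModule` instances that has already been loaded,
# and that `compute_loss` and `data_loader` come from the surrounding training
# code; the optimizer and learning rate below are arbitrary choices.
#
#     fine_tuner = FineTuneBiases(layers)
#     fine_tuner.set_trainable_params()
#     optimizer = torch.optim.Adam(fine_tuner.get_trainable_params().values(), lr=1e-4)
#
#     for batch in data_loader:
#         loss = compute_loss(layers, batch)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#
#     # Save / restore only the fine-tuned biases.
#     torch.save(fine_tuner.state_dict(), 'fine_tuned_biases.pt')
#     fine_tuner.load_state_dict(torch.load('fine_tuned_biases.pt'))
# ---------------------------------------------------------------------------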
