"""
|
|
---
|
|
title: Finetune GPT-2 with LoRA
|
|
summary: This is training code with notes for fine-tuning pre-trained GPT-2 model with LoRA.
|
|
---
|
|
|
|
# Finetune [GPT-2](gpt2.html) with [LoRA](index.html)
|
|
|
|
Here's a Colab notebook for training a feedback transformer on Tiny Shakespeare dataset.
|
|
|
|
[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/lora/experiment.ipynb)
|
|
"""

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from labml import lab, monit, tracker
from labml.configs import BaseConfigs, option
from labml.utils.download import download_file
from labml_helpers.device import DeviceConfigs
from labml_nn.lora.gpt2 import GPTModel


class Trainer(BaseConfigs):
    """
    ## Trainer configurations and the training loop

    The default configs can and will be overridden when we start the experiment.
    """

    device: torch.device = DeviceConfigs()

    # GPT-2 configs
    layer_norm_epsilon: float = 1e-05
    d_model: int = 768
    n_layers: int = 12
    n_heads: int = 12
    n_positions: int = 1024
    vocab_size: int = 50257

    # Training configs
    epochs: int = 10
    batch_size: int = 32
    learning_rate: float = 1e-4
    context_len: int = 512

    # LoRA rank
    lora_r: int = 32
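    # (Following the LoRA paper, each adapted weight $W_0$ gets a trainable low-rank
    # update $\Delta W = \frac{\alpha}{r} B A$ with $B \in \mathbb{R}^{d \times r}$ and
    # $A \in \mathbb{R}^{r \times k}$; `lora_r` is that rank $r$.)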

    # Dataset
    text: TensorDataset = "tiny_shakespeare"
    # Huggingface tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # [GPT2 model](gpt2.html)
    model: GPTModel
    # Optimizer
    optimizer: torch.optim.Adam
    # Cross entropy loss
    loss_func = torch.nn.CrossEntropyLoss()
    # Dataloader
    data_loader: DataLoader

    def _load_pretrained_weights(self):
        """
        ### Load pre-trained [GPT-2 from huggingface](https://huggingface.co/openai-community/gpt2)
        """

        # Load the huggingface model and get the parameters
        hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
        state_dict = hf_model.state_dict()

        # Transformer embedding and prediction layer parameter mapping (`hf: ours`)
        mapping = {
            'transformer.wte.weight': 'token_embedding.weight',
            'transformer.wpe.weight': 'position_embedding.weight',
            'transformer.ln_f.weight': 'final_norm.weight',
            'transformer.ln_f.bias': 'final_norm.bias',
            'lm_head.weight': 'lm_head.weight'
        }

        # Mapping (`hf: ours`) of decoder layers
        for i in range(12):
            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.qkv_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.qkv_projection.bias'
            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.output_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.output_projection.bias'
            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.linear_in.weight'
            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.linear_in.bias'
            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.linear_out.weight'
            mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.linear_out.bias'

        # Move the parameters based on mapping
        new_state_dict = {}
        for old_key, new_key in mapping.items():
            if old_key in state_dict:
                new_state_dict[new_key] = state_dict[old_key]

        # GPT-2 hugging face uses 1D Convolution layers. We need to transpose those weights since we use linear layers
        convo_layers = ([f'blocks.{i}.ffn.linear_in.weight' for i in range(12)] +
                        [f'blocks.{i}.ffn.linear_out.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.qkv_projection.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.output_projection.weight' for i in range(12)])

        for layer in convo_layers:
            new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
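
        # (Note: Hugging Face's `Conv1D` stores its weight as `[n_in, n_out]`, while
        # `torch.nn.Linear` expects `[n_out, n_in]`, hence the transpose above.)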

        # Load our model. We use `strict = False` because the state dict does not have the LoRA weights
        missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)

        # Make sure that only the LoRA weights are missing
        assert all('lora' in key for key in missing_keys)
        assert not unexpected_keys

    def initialize(self):
        """
        ### Initialize the model, optimizer and dataloader
        """
        # Initialize the [GPT2 model](gpt2.html)
        self.model = GPTModel(
            layer_norm_epsilon=self.layer_norm_epsilon,
            d_model=self.d_model,
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            n_positions=self.n_positions,
            vocab_size=self.vocab_size,
            r=self.lora_r,
        )
        self.model.to(self.device)
        # Load pre-trained model weights
        self._load_pretrained_weights()

        # Initialize the optimizer
        self.optimizer = Adam(self.model.parameters(), lr=self.learning_rate)

        # Initialize the data loader
        self.data_loader = DataLoader(self.text, batch_size=self.batch_size, shuffle=True)

    def run(self):
        """
        ### Training loop
        """

        for _ in monit.loop(self.epochs):
            # `inputs` has shape `[batch_size, seq_len]`
            for (inputs,) in monit.iterate('Train', self.data_loader):
                # Move `inputs` to device
                inputs = inputs.to(self.device)
                # Call the model with all but the last token
                logits = self.model(inputs[:, :-1])
                # Get cross entropy loss
                loss = self.loss_func(logits.reshape(-1, logits.shape[-1]), inputs[:, 1:].reshape(-1))
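                # (The targets `inputs[:, 1:]` are the inputs shifted by one position,
                # so each position is trained to predict the next token.)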

                # Make gradients 0
                self.optimizer.zero_grad()
                # Compute gradients
                loss.backward()
                # Optimize
                self.optimizer.step()

                # Log the loss
                tracker.save({'loss': loss})
                tracker.add_global_step()
            # Start a new console output line after each epoch
            tracker.new_line()


@option(Trainer.text)
def tiny_shakespeare(c: Trainer):
    """
    ### Tiny Shakespeare dataset

    It will be downloaded from the URL if it is not present.
    """
    path = lab.get_data_path() / 'tiny_shakespeare.txt'
    if not path.exists():
        download_file("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", path)
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Tokenize the whole text
    tokens = c.tokenizer.encode(text)
    # Trim the tokens so they split evenly into sequences of `context_len` tokens
    num_batches = len(tokens) // (c.batch_size * c.context_len)
    tokens = tokens[:num_batches * c.batch_size * c.context_len]
    # Reshape into sequences of `context_len` tokens
    input_ids = torch.tensor(tokens).view(-1, c.context_len)
    return TensorDataset(input_ids)
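

# What follows is a minimal sketch of how this trainer could be wired into a
# `labml` experiment; the experiment name and the override example are
# illustrative assumptions rather than part of the annotated implementation.
def main():
    # Import the experiment API locally so the sketch stays self-contained
    from labml import experiment

    # Create an experiment (the name is an assumption)
    experiment.create(name="lora_gpt2")
    # Create the trainer configurations
    conf = Trainer()
    # Register the configs with the experiment; defaults could be overridden here,
    # e.g. `experiment.configs(conf, {'epochs': 3, 'lora_r': 4})`
    experiment.configs(conf)
    # Initialize the model, optimizer and dataloader
    conf.initialize()
    # Register the model so it is checkpointed
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()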