Finetune GPT-2 with LoRA

Here's a Colab notebook for finetuning GPT-2 with LoRA on the Tiny Shakespeare dataset.


import torch
from labml import lab, monit, tracker
from labml.configs import BaseConfigs, option
from labml.utils.download import download_file
from labml_helpers.device import DeviceConfigs
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from labml_nn.lora.gpt2 import GPTModel

Trainer configurations and the training loop

The default configs can and will be overridden when we start the experiment; a launch sketch at the end of this page shows how.

class Trainer(BaseConfigs):
    device: torch.device = DeviceConfigs()

GPT-2 configs

    layer_norm_epsilon: float = 1e-05
    n_embed: int = 768
    n_layer: int = 12
    n_positions: int = 1024
    vocab_size: int = 50257
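
These values match the base gpt2 checkpoint. As a quick aside (a minimal sketch, not part of the trainer), they can be checked against the Hugging Face config:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("gpt2")
# Prints: 768 12 1024 50257 1e-05
print(cfg.n_embd, cfg.n_layer, cfg.n_positions, cfg.vocab_size, cfg.layer_norm_epsilon)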

Training configs

    epochs: int = 10
    batch_size: int = 32
    learning_rate: float = 1e-4
    context_len: int = 512

LoRA rank

    lora_r: int = 32
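
With r = 32 the adapters stay small: for a 768 × 768 weight matrix, LoRA adds two matrices of shapes 768 × 32 and 32 × 768, i.e. 2 × 768 × 32 = 49,152 trainable parameters per adapted weight, instead of the 768 × 768 = 589,824 parameters of the full matrix.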

Dataset. Assigning the string "tiny_shakespeare" to text tells labml to build the dataset with the config option of that name, defined at the bottom of this page.

    text: TensorDataset = "tiny_shakespeare"
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model: GPTModel
    optimizer: torch.optim.Adam
    criterion = torch.nn.CrossEntropyLoss()
    data_loader: DataLoader

Load the pre-trained GPT-2 weights from Hugging Face

    def _load_pretrained_weights(self):

Load the Hugging Face model and get its parameters

        hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
        state_dict = hf_model.state_dict()

Parameter mapping for the transformer embeddings and the prediction layer (Hugging Face name → our name)

        mapping = {
            'transformer.wte.weight': 'token_embedding.weight',
            'transformer.wpe.weight': 'position_embedding.weight',
            'transformer.ln_f.weight': 'final_norm.weight',
            'transformer.ln_f.bias': 'final_norm.bias',
            'lm_head.weight': 'lm_head.weight'
        }

Parameter mapping for the decoder layers (Hugging Face name → our name)

        for i in range(12):
            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
            mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'

Copy the parameters into a new state dict based on the mapping

        new_state_dict = {}
        for old_key, new_key in mapping.items():
            if old_key in state_dict:
                new_state_dict[new_key] = state_dict[old_key]

The Hugging Face GPT-2 implementation uses 1D convolution (Conv1D) layers. We need to transpose those weights since we use linear layers (see the short shape check after this method).

        convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
                        [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.c_proj.weight' for i in range(12)])

        for layer in convo_layers:
            new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

Load our model

        self.model.load_state_dict(new_state_dict, strict=False)  # state dict does not have lora weights
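
As a quick sanity check of the transpose above, one can inspect the shapes directly (a minimal sketch, not part of the trainer, reusing the imports at the top of this page):

hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
w = hf_model.transformer.h[0].attn.c_attn.weight
print(w.shape)                          # torch.Size([768, 2304]): Conv1D stores (in_features, out_features)
print(torch.transpose(w, 0, 1).shape)   # torch.Size([2304, 768]): the (out_features, in_features) layout of a linear layer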

Initialize the model, optimizer and data loader

    def initialize(self):

Initialize the model

        self.model = GPTModel(
            layer_norm_epsilon=self.layer_norm_epsilon,
            n_embd=self.n_embed,
            n_layer=self.n_layer,
            n_positions=self.n_positions,
            vocab_size=self.vocab_size,
            r=self.lora_r,
        )
        self.model.to(self.device)

Load pre-trained model weights

        self._load_pretrained_weights()

Initialize the optimizer

        self.optimizer = Adam(self.model.parameters(), lr=self.learning_rate)

Initialize the data loader

        self.data_loader = DataLoader(self.text, batch_size=self.batch_size, shuffle=True)

Training loop

    def run(self):
        for _ in monit.loop(self.epochs):
            for i, batch in monit.enum('Train', self.data_loader):
                inputs = batch[0]
                inputs = inputs.to(self.device)
                # The model is trained to predict the next token, so the labels are the inputs
                labels = inputs.clone()

                outputs = self.model(inputs)

                # Shift so that the logits at position t are scored against the token at position t + 1
                shift_logits = outputs[..., :-1, :]
                shift_labels = labels[..., 1:]

                # Cross-entropy loss over the flattened tokens
                loss = self.criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Log the loss and advance the global step
                tracker.add({'loss': loss})

                tracker.save()
                tracker.add_global_step()
            tracker.new_line()

Tiny Shakespeare dataset

It will be downloaded from the URL below if it's not present locally.

@option(Trainer.text)
def tiny_shakespeare(c: Trainer):
    path = lab.get_data_path() / 'tiny_shakespeare.txt'
    if not path.exists():
        download_file("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", path)
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Tokenize the whole text, drop the tail that doesn't fill a full batch,
    # and reshape into fixed-length context windows
    tokens = c.tokenizer.encode(text)
    num_batches = len(tokens) // (c.batch_size * c.context_len)
    tokens = tokens[:num_batches * c.batch_size * c.context_len]
    input_ids = torch.tensor(tokens).view(-1, c.context_len)
    return TensorDataset(input_ids)
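
Finally, a minimal sketch of how this trainer might be launched with labml's experiment API, overriding some of the defaults from above. The experiment name and the override values here are illustrative assumptions, not part of the code on this page:

from labml import experiment


def main():
    # Create an experiment and a trainer with the default configs
    experiment.create(name="lora_gpt2")  # the experiment name is an assumption
    conf = Trainer()
    # Override any of the config fields defined above
    experiment.configs(conf, {'epochs': 3, 'lora_r': 8, 'batch_size': 16})
    # Build the model, optimizer and data loader, then train
    conf.initialize()
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()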