Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-26 08:41:23 +08:00)
Merge pull request #266 from lakshith-403/LoRA
130 labml_nn/transformers/LoRA/GPT2.py Normal file
@@ -0,0 +1,130 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer

from labml_nn.transformers.LoRA import Linear, Embedding

tokenizer = AutoTokenizer.from_pretrained("gpt2")

config = {
    "layer_norm_epsilon": 1e-05,
    "n_embd": 768,
    "n_head": 12,
    "n_layer": 12,
    "n_positions": 1024,
    "vocab_size": 50257,
    "device": "cuda"
}


class FFN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.c_fc = Linear(config['n_embd'], dim, r=32, bias=True)
        self.c_proj = Linear(dim, config['n_embd'], r=32, bias=True)
        self.act = nn.functional.gelu

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return hidden_states


class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_dim = config['n_embd']
        self.num_heads = config['n_head']
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        # Fused query/key/value projection and the output projection, both LoRA-wrapped
        self.c_att = Linear(config['n_embd'], config['n_embd'] * 3, r=32, bias=True)
        self.c_proj = Linear(config['n_embd'], config['n_embd'], r=32, bias=True)

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits the hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, hidden_states):
        batch_size, seq_length, _ = hidden_states.size()

        query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=True,  # for the causal (triangular) mask
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)

        attn_output = self.c_proj(attn_output)

        return attn_output


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.pre_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
        self.attn = MultiHeadAttention()
        self.post_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
        self.ffn = FFN(config['n_embd'] * 4)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.pre_norm(hidden_states)

        attn_output = self.attn(hidden_states)

        hidden_states = attn_output + residual
        residual = hidden_states
        hidden_states = self.post_norm(hidden_states)
        feed_forward_output = self.ffn(hidden_states)
        hidden_states = feed_forward_output + residual

        return hidden_states


class GPTModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.token_embedding = Embedding(config['vocab_size'], config['n_embd'], r=32)
        self.position_embedding = Embedding(config['n_positions'], config['n_embd'], r=32)

        self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])

        self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])

        self.lm_head = Linear(config['n_embd'], config['vocab_size'], r=32, bias=False)

    def forward(self, input_ids):
        batch_size, seq_length = input_ids.size()

        token_embeddings = self.token_embedding(input_ids)  # B T C
        position_ids = torch.arange(seq_length, device=config['device'])  # T
        position_embeddings = self.position_embedding(position_ids)  # T C, broadcast over the batch

        hidden_states = token_embeddings + position_embeddings

        for block in self.blocks:
            hidden_states = block(hidden_states)

        hidden_states = self.final_norm(hidden_states)

        logits = self.lm_head(hidden_states)

        return logits
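Every weight matrix in this model goes through the LoRA Linear or Embedding wrappers, which freeze the pre-trained parameters, so only the low-rank factors and the LayerNorms receive gradients. A minimal sketch, not part of the diff and assuming the module above is importable, to confirm the parameter split:

# Sketch (not in the PR): count trainable vs. frozen parameters of GPTModel.
from labml_nn.transformers.LoRA.GPT2 import GPTModel

model = GPTModel()

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)

# The frozen count covers the pre-trained GPT-2 weights; the trainable count is
# dominated by the r=32 lora_a / lora_b factors plus the LayerNorm parameters.
print(f"trainable: {trainable:,}  frozen: {frozen:,}")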
68 labml_nn/transformers/LoRA/__init__.py Normal file
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn


class Linear(nn.Module):
    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool,
            r: int,
            alpha: int = None):
        if alpha is None:
            alpha = r
        super().__init__()
        # Frozen pre-trained weight
        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        self.weight.requires_grad = False

        if bias:
            # Frozen pre-trained bias
            self.bias = nn.Parameter(torch.empty(out_features))
            self.bias.requires_grad = False
        else:
            self.bias = None

        # Trainable low-rank factors; the update lora_a @ lora_b is scaled by alpha / r
        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((in_features, r)))
        self.lora_b = nn.Parameter(torch.empty((r, out_features)))

        with torch.no_grad():
            # lora_b starts at zero, so the layer initially matches the frozen weights
            nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        # Frozen linear transformation
        result = nn.functional.linear(x, self.weight, bias=self.bias)

        # Add the scaled low-rank update
        result += (x @ self.lora_a @ self.lora_b) * self.scaling

        return result


class Embedding(nn.Module):
    def __init__(
            self,
            num_embeddings: int,
            embedding_dim: int,
            r: int,
            alpha: int = None,
    ):
        if alpha is None:
            alpha = r
        super().__init__()

        # Frozen pre-trained embedding matrix
        self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
        self.weight.requires_grad = False

        # Trainable low-rank factors
        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((num_embeddings, r)))
        self.lora_b = nn.Parameter(torch.empty((r, embedding_dim)))

        with torch.no_grad():
            nn.init.normal_(self.lora_a)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        # Frozen embedding lookup plus the scaled low-rank update
        result = nn.functional.embedding(x, self.weight)
        result += (nn.functional.embedding(x, self.lora_a) @ self.lora_b) * self.scaling

        return result
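The forward pass above computes h = x W0^T + b + (alpha / r) * (x A) B, with B (lora_b) initialised to zero, so a freshly wrapped layer reproduces the frozen weights exactly until training updates the factors. A short sketch, not part of the diff, that checks this identity-at-init property:

# Sketch (not in the PR): a LoRA Linear equals the plain frozen linear at initialisation.
import torch
from labml_nn.transformers.LoRA import Linear

layer = Linear(in_features=768, out_features=768, bias=True, r=32)
with torch.no_grad():
    # The frozen weight is created with torch.empty, so give it concrete values first
    layer.weight.normal_()
    layer.bias.zero_()

x = torch.randn(1, 4, 768)
base = torch.nn.functional.linear(x, layer.weight, layer.bias)
assert torch.allclose(layer(x), base)  # lora_b == 0, so the low-rank update adds nothing yet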
97 labml_nn/transformers/LoRA/experiment.ipynb Normal file
@@ -0,0 +1,97 @@
{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
    "import torch"
   ],
   "id": "cffa3ec341b4905a",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
   ],
   "id": "c2b0b7e18394ea9e",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "model = GPTModel()\n",
    "\n",
    "state_dict = torch.load('transformed.pth')\n",
    "\n",
    "missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)\n",
    "if missing_keys:\n",
    "    print(f\"Missing keys: {missing_keys}\")\n",
    "if unexpected_keys:\n",
    "    print(f\"Unexpected keys: {unexpected_keys}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "prompt = \"hello how are you\"\n",
    "tokenized = tokenizer(prompt, return_tensors=\"pt\")\n",
    "tokenized['input_ids'] = tokenized['input_ids'].to('cuda')\n",
    "model = model.to('cuda')\n",
    "\n",
    "with torch.no_grad():\n",
    "    model.eval()\n",
    "    res = model(tokenized['input_ids'])\n",
    "\n",
    "output_ids = torch.argmax(res, dim=-1)\n",
    "for id in output_ids[0]:\n",
    "    print(tokenizer.decode(id))"
   ],
   "id": "f4f7826ec3729b66",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "",
   "id": "c12776360008a974",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
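The inference cell in this notebook takes a single argmax over the prompt positions, which only shows the model's next-token guess at each position of the prompt. For actual text generation the predictions would be fed back in autoregressively; a minimal greedy-decoding sketch, not part of the diff and assuming transformed.pth has been produced by load_hf.py, follows:

# Sketch (not in the PR): greedy autoregressive decoding with the same model and tokenizer.
import torch
from transformers import AutoTokenizer
from labml_nn.transformers.LoRA.GPT2 import GPTModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPTModel()
model.load_state_dict(torch.load('transformed.pth', weights_only=True), strict=False)
model = model.to('cuda').eval()

ids = tokenizer("hello how are you", return_tensors="pt")['input_ids'].to('cuda')
with torch.no_grad():
    for _ in range(20):  # generate 20 new tokens, one at a time
        logits = model(ids)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
print(tokenizer.decode(ids[0]))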
44 labml_nn/transformers/LoRA/load_hf.py Normal file
@@ -0,0 +1,44 @@
import torch
from transformers import AutoModelForCausalLM

# Load the pre-trained GPT-2 weights from Hugging Face
model = AutoModelForCausalLM.from_pretrained("gpt2")

state_dict = model.state_dict()

# Map Hugging Face parameter names to the names used by GPTModel
mapping = {
    'transformer.wte.weight': 'token_embedding.weight',
    'transformer.wpe.weight': 'position_embedding.weight',
    'transformer.ln_f.weight': 'final_norm.weight',
    'transformer.ln_f.bias': 'final_norm.bias',
    'lm_head.weight': 'lm_head.weight'
}

for i in range(12):
    mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
    mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
    mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
    mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
    mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
    mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
    mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
    mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
    mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
    mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
    mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
    mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'

new_state_dict = {}
for old_key, new_key in mapping.items():
    if old_key in state_dict:
        new_state_dict[new_key] = state_dict[old_key]

# Transpose the weight matrices of the Conv1D layers so they can be loaded into linear layers instead
convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
                [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
                [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
                [f'blocks.{i}.attn.c_proj.weight' for i in range(12)])

for layer in convo_layers:
    new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

torch.save(new_state_dict, 'transformed.pth')
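The transpose is needed because Hugging Face's GPT-2 implements these projections as Conv1D modules whose weights are stored as (in_features, out_features), while nn.functional.linear expects (out_features, in_features). A quick sanity-check sketch, not part of the diff, comparing one layer's shape before and after conversion:

# Sketch (not in the PR): compare the HF Conv1D weight shape with the converted linear weight.
import torch
from transformers import AutoModelForCausalLM

hf = AutoModelForCausalLM.from_pretrained("gpt2")
converted = torch.load('transformed.pth', weights_only=True)

print(hf.state_dict()['transformer.h.0.mlp.c_fc.weight'].shape)  # torch.Size([768, 3072])
print(converted['blocks.0.ffn.c_fc.weight'].shape)               # torch.Size([3072, 768])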
215 labml_nn/transformers/LoRA/train.ipynb Normal file
@@ -0,0 +1,215 @@
{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "source": "# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "3b1e507015ba6b81",
   "metadata": {},
   "source": [
    "with open('input.txt', 'r', encoding='utf-8') as f:\n",
    "    text = f.read()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "ac8e51ae5bbfcae7",
   "metadata": {},
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
    "\n",
    "tokens = tokenizer.encode(text, add_special_tokens=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "aeefcdf813e427e",
   "metadata": {},
   "source": [
    "context_length = 512\n",
    "batch_size = 2"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "a384b42274f008a2",
   "metadata": {},
   "source": [
    "num_batches = len(tokens) // (batch_size * context_length)\n",
    "tokens = tokens[:num_batches * batch_size * context_length]"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "5c4cc78ac1a02c1d",
   "metadata": {},
   "source": [
    "import torch\n",
    "\n",
    "input_ids = torch.tensor(tokens).view(-1, context_length)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "7037fd75e2161382",
   "metadata": {},
   "source": [
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from torch.optim import Adam\n",
    "from torch.utils.data import random_split\n",
    "\n",
    "dataset = TensorDataset(input_ids)\n",
    "\n",
    "train_ratio = 0.8\n",
    "test_ratio = 0.2\n",
    "\n",
    "train_size = int(train_ratio * len(dataset))\n",
    "test_size = len(dataset) - train_size\n",
    "\n",
    "train_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n",
    "\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "a98b7baa064b8494",
   "metadata": {},
   "source": [
    "from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
    "\n",
    "model = GPTModel()\n",
    "state_dict = torch.load('transformed.pth', weights_only=True)\n",
    "\n",
    "_ = model.load_state_dict(state_dict, strict=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "device = \"cuda\"\n",
    "model = model.to(device=\"cuda\")"
   ],
   "id": "2e0fa8b3082df716",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "e2f5076894770740",
   "metadata": {},
   "source": [
    "from labml import tracker, experiment\n",
    "\n",
    "optimizer = Adam(model.parameters(), lr=5e-5)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "model.train()\n",
    "epochs = 3\n",
    "step = 0\n",
    "\n",
    "with experiment.record(name='LoRA.GPT2', app_url='http://localhost:5005/api/v1/track'):\n",
    "    for epoch in range(epochs):\n",
    "        for batch in train_dataloader:\n",
    "            inputs = batch[0]\n",
    "            inputs = inputs.to(device)\n",
    "            labels = inputs.clone()\n",
    "\n",
    "            outputs = model(inputs)\n",
    "\n",
    "            shift_logits = outputs[..., :-1, :]\n",
    "            shift_labels = labels[..., 1:]\n",
    "\n",
    "            loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            tracker.save(step, {'loss': loss})\n",
    "            step += 1\n",
    "        print(f'Epoch: {epoch + 1}, Loss: {loss.item()}')\n",
    "\n",
    "        test_loss = 0\n",
    "        for batch in test_dataloader:\n",
    "            inputs = batch[0]\n",
    "            inputs = inputs.to(device)\n",
    "            labels = inputs.clone()\n",
    "\n",
    "            outputs = model(inputs)\n",
    "\n",
    "            shift_logits = outputs[..., :-1, :]\n",
    "            shift_labels = labels[..., 1:]\n",
    "\n",
    "            loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
    "\n",
    "            test_loss += loss.item()\n",
    "        test_loss /= len(test_dataloader)\n",
    "        tracker.save(step, {'test_loss': test_loss})\n",
    "\n",
    "\n",
    "print(\"Training complete.\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "da2d4023002648dc",
   "metadata": {},
   "source": [],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "base"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
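The training cell above hands model.parameters() to Adam; because the LoRA wrappers create the pre-trained weights with requires_grad=False, those tensors never receive gradients and only the low-rank factors and LayerNorm parameters are actually updated. A sketch, not part of the diff, of the more explicit variant that passes just the trainable parameters to the optimizer:

# Sketch (not in the PR): build the optimizer over the trainable (LoRA + LayerNorm) parameters only.
import torch
from torch.optim import Adam
from labml_nn.transformers.LoRA.GPT2 import GPTModel

model = GPTModel()
model.load_state_dict(torch.load('transformed.pth', weights_only=True), strict=False)

trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = Adam(trainable_params, lr=5e-5)
print(f"optimizing {sum(p.numel() for p in trainable_params):,} parameters")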