Merge pull request #266 from lakshith-403/LoRA

Authored by Varuna Jayasiri on 2024-07-31 21:06:28 +05:30, committed by GitHub
5 changed files with 554 additions and 0 deletions


@@ -0,0 +1,130 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from labml_nn.transformers.LoRA import Linear, Embedding
tokenizer = AutoTokenizer.from_pretrained("gpt2")
config = {
    "layer_norm_epsilon": 1e-05,
    "n_embd": 768,
    "n_head": 12,
    "n_layer": 12,
    "n_positions": 1024,
    "vocab_size": 50257,
    "device": "cuda"
}
class FFN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.c_fc = Linear(config['n_embd'], dim, r=32, bias=True)
        self.c_proj = Linear(dim, config['n_embd'], r=32, bias=True)
        self.act = nn.functional.gelu

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return hidden_states
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_dim = config['n_embd']
        self.num_heads = config['n_head']
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        self.c_att = Linear(config['n_embd'], config['n_embd'] * 3, r=32, bias=True)
        self.c_proj = Linear(config['n_embd'], config['n_embd'], r=32, bias=True)

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, hidden_states):
        batch_size, seq_length, _ = hidden_states.size()

        query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=True,  # for the triangular mask
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)

        attn_output = self.c_proj(attn_output)

        return attn_output
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.pre_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
        self.attn = MultiHeadAttention()
        self.post_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
        self.ffn = FFN(config['n_embd'] * 4)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.pre_norm(hidden_states)

        attn_output = self.attn(hidden_states)

        hidden_states = attn_output + residual
        residual = hidden_states
        hidden_states = self.post_norm(hidden_states)
        feed_forward_output = self.ffn(hidden_states)
        hidden_states = feed_forward_output + residual

        return hidden_states
class GPTModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.token_embedding = Embedding(config['vocab_size'], config['n_embd'], r=32)
        self.position_embedding = Embedding(config['n_positions'], config['n_embd'], r=32)

        self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])

        self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])

        self.lm_head = Linear(config['n_embd'], config['vocab_size'], r=32, bias=False)

    def forward(self, input_ids):
        batch_size, seq_length = input_ids.size()

        token_embeddings = self.token_embedding(input_ids)  # (batch, seq, n_embd)
        position_ids = torch.arange(seq_length, device=config['device'])  # (seq,)
        position_embeddings = self.position_embedding(position_ids)  # (seq, n_embd), broadcast over the batch

        hidden_states = token_embeddings + position_embeddings

        for block in self.blocks:
            hidden_states = block(hidden_states)

        hidden_states = self.final_norm(hidden_states)

        logits = self.lm_head(hidden_states)

        return logits
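
Only the `lora_a` / `lora_b` factors and the `nn.LayerNorm` parameters are trainable here; every base weight is created with `requires_grad = False` and is meant to be loaded from a converted GPT-2 checkpoint. A minimal sketch, using only names defined above, to sanity-check which parameters would receive gradients:

model = GPTModel()
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
# Expect the LoRA factors and the LayerNorm parameters in `trainable`, and the
# pretrained weights (embeddings, attention, FFN, lm_head) in `frozen`.
print(f'{len(trainable)} trainable tensors, {len(frozen)} frozen tensors')
print([name for name in trainable if "lora" in name][:4])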


@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
class Linear(nn.Module):
    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool,
            r: int,
            alpha: int = None):
        if alpha is None:
            alpha = r
        super().__init__()
        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        self.weight.requires_grad = False

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
            self.bias.requires_grad = False
        else:
            self.bias = None

        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((in_features, r)))
        self.lora_b = nn.Parameter(torch.empty((r, out_features)))

        with torch.no_grad():
            nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        result = nn.functional.linear(x, self.weight, bias=self.bias)

        result += (x @ self.lora_a @ self.lora_b) * self.scaling

        return result
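
# Note: the forward pass above keeps the frozen weight and the low-rank update
# separate; it is equivalent to using a merged weight
#     W_eff = W + (alpha / r) * (A @ B).T
# which is how LoRA updates are usually folded in for inference. A hedged sketch
# of that fold; the helper name `merge_lora` is illustrative only.
def merge_lora(layer: Linear) -> None:
    # `weight` is (out_features, in_features) while `lora_a @ lora_b` is
    # (in_features, out_features), hence the transpose before adding.
    with torch.no_grad():
        layer.weight += (layer.lora_a @ layer.lora_b).T * layer.scaling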
class Embedding(nn.Module):
    def __init__(
            self,
            num_embeddings: int,
            embedding_dim: int,
            r: int,
            alpha: int = None,
    ):
        if alpha is None:
            alpha = r
        super().__init__()

        self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
        self.weight.requires_grad = False

        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((num_embeddings, r)))
        self.lora_b = nn.Parameter(torch.empty((r, embedding_dim)))

        with torch.no_grad():
            nn.init.normal_(self.lora_a)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        result = nn.functional.embedding(x, self.weight)
        result += (nn.functional.embedding(x, self.lora_a) @ self.lora_b) * self.scaling

        return result
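
In the embedding variant the low-rank path is a second lookup into `lora_a` (num_embeddings x r) followed by a projection with `lora_b` (r x embedding_dim), so both terms in `forward` share the output shape. A minimal shape check using the class above (the base weight stays uninitialized until a checkpoint is loaded, so only shapes are checked):

emb = Embedding(num_embeddings=50257, embedding_dim=768, r=32)
ids = torch.randint(0, 50257, (2, 16))
out = emb(ids)
# Frozen lookup and scaled low-rank lookup are both (batch, seq, embedding_dim).
assert out.shape == (2, 16, 768)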


@@ -0,0 +1,97 @@
{
"cells": [
{
"metadata": {},
"cell_type": "code",
"source": [
"from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
"import torch"
],
"id": "cffa3ec341b4905a",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
],
"id": "c2b0b7e18394ea9e",
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true
},
"source": [
"model = GPTModel()\n",
"\n",
"state_dict = torch.load('transformed.pth')\n",
"\n",
"missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)\n",
"if missing_keys:\n",
" print(f\"Missing keys: {missing_keys}\")\n",
"if unexpected_keys:\n",
" print(f\"Unexpected keys: {unexpected_keys}\")"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"prompt = \"hello how are you\"\n",
"tokenized = tokenizer(prompt, return_tensors=\"pt\")\n",
"tokenized['input_ids'] = tokenized['input_ids'].to('cuda')\n",
"model = model.to('cuda')\n",
"\n",
"with torch.no_grad():\n",
" model.eval()\n",
" res = model(tokenized['input_ids'])\n",
"\n",
"output_ids = torch.argmax(res, dim=-1)\n",
"for id in output_ids[0]:\n",
" print(tokenizer.decode(id))"
],
"id": "f4f7826ec3729b66",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "",
"id": "c12776360008a974",
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
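
The last populated cell above runs a single forward pass and prints the greedy prediction for every prompt position; it is not autoregressive generation. A minimal greedy-decoding sketch built from the same pieces (it assumes the `model` and `tokenizer` from this notebook and an arbitrary budget of 20 new tokens):

import torch

input_ids = tokenizer("hello how are you", return_tensors="pt")['input_ids'].to('cuda')

model.eval()
with torch.no_grad():
    for _ in range(20):
        logits = model(input_ids)                           # (1, seq, vocab)
        next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=1)  # append the greedy token

print(tokenizer.decode(input_ids[0]))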


@@ -0,0 +1,44 @@
import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("gpt2")
state_dict = model.state_dict()
mapping = {
    'transformer.wte.weight': 'token_embedding.weight',
    'transformer.wpe.weight': 'position_embedding.weight',
    'transformer.ln_f.weight': 'final_norm.weight',
    'transformer.ln_f.bias': 'final_norm.bias',
    'lm_head.weight': 'lm_head.weight'
}
for i in range(12):
    mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
    mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
    mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
    mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
    mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
    mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
    mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
    mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
    mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
    mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
    mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
    mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'
new_state_dict = {}
for old_key, new_key in mapping.items():
    if old_key in state_dict:
        new_state_dict[new_key] = state_dict[old_key]

# Transpose the weight matrices of the Conv1D layers so they can be loaded into linear layers
convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
                [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
                [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
                [f'blocks.{i}.attn.c_proj.weight' for i in range(12)])

for layer in convo_layers:
    new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
torch.save(new_state_dict, 'transformed.pth')
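
The transposes are needed because Hugging Face's GPT-2 implements these projections with `Conv1D`, whose weight is stored as (in_features, out_features), while the `Linear` layer in this PR follows the `nn.Linear` convention of (out_features, in_features). A hedged spot check of one converted layer (key names taken from the mapping above):

w_hf = state_dict['transformer.h.0.mlp.c_fc.weight']    # Conv1D layout: (768, 3072)
w_new = new_state_dict['blocks.0.ffn.c_fc.weight']      # Linear layout: (3072, 768)
assert w_new.shape == w_hf.shape[::-1]
assert torch.equal(w_new, w_hf.T)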


@@ -0,0 +1,215 @@
{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"jupyter": {
"outputs_hidden": true
}
},
"source": "# !wget https://raw.github/zusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "3b1e507015ba6b81",
"metadata": {},
"source": [
"with open('input.txt', 'r', encoding='utf-8') as f:\n",
" text = f.read()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "ac8e51ae5bbfcae7",
"metadata": {},
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
"\n",
"tokens = tokenizer.encode(text, add_special_tokens=False)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "aeefcdf813e427e",
"metadata": {},
"source": [
"context_length = 512\n",
"batch_size = 2"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "a384b42274f008a2",
"metadata": {},
"source": [
"num_batches = len(tokens) // (batch_size * context_length)\n",
"tokens = tokens[:num_batches * batch_size * context_length]"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "5c4cc78ac1a02c1d",
"metadata": {},
"source": [
"import torch\n",
"\n",
"input_ids = torch.tensor(tokens).view(-1, context_length)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "7037fd75e2161382",
"metadata": {},
"source": [
"from torch.utils.data import DataLoader, TensorDataset\n",
"from torch.optim import Adam\n",
"from torch.utils.data import random_split\n",
"\n",
"dataset = TensorDataset(input_ids)\n",
"\n",
"train_ratio = 0.8\n",
"test_ratio = 0.2\n",
"\n",
"train_size = int(train_ratio * len(dataset))\n",
"test_size = len(dataset) - train_size\n",
"\n",
"train_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n",
"\n",
"train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
"test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "a98b7baa064b8494",
"metadata": {},
"source": [
"from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
"\n",
"model = GPTModel()\n",
"state_dict = torch.load('transformed.pth', weights_only=True)\n",
"\n",
"_ = model.load_state_dict(state_dict, strict=False)"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"device = \"cuda\"\n",
"model = model.to(device=\"cuda\")"
],
"id": "2e0fa8b3082df716",
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "e2f5076894770740",
"metadata": {},
"source": [
"from labml import tracker, experiment\n",
"\n",
"optimizer = Adam(model.parameters(), lr=5e-5)\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"\n",
"model.train()\n",
"epochs = 3\n",
"step = 0\n",
"\n",
"with experiment.record(name='LoRA.GPT2', app_url='http://localhost:5005/api/v1/track'):\n",
" for epoch in range(epochs):\n",
" for batch in train_dataloader:\n",
" inputs = batch[0]\n",
" inputs = inputs.to(device)\n",
" labels = inputs.clone()\n",
" \n",
" outputs = model(inputs)\n",
" \n",
" shift_logits = outputs[..., :-1, :]\n",
" shift_labels = labels[..., 1:]\n",
" \n",
" loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" tracker.save(step, {'loss': loss})\n",
" step += 1\n",
" print(f'Epoch: {epoch + 1}, Loss: {loss.item()}')\n",
" \n",
" test_loss = 0\n",
" for batch in test_dataloader:\n",
" inputs = batch[0]\n",
" inputs = inputs.to(device)\n",
" labels = inputs.clone()\n",
" \n",
" outputs = model(inputs)\n",
" \n",
" shift_logits = outputs[..., :-1, :]\n",
" shift_labels = labels[..., 1:]\n",
" \n",
" loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
" \n",
" test_loss += loss.item()\n",
" test_loss /= len(test_dataloader)\n",
" tracker.save(step, {'test_loss': test_loss})\n",
" \n",
"\n",
"print(\"Training complete.\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "da2d4023002648dc",
"metadata": {},
"source": [],
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "base"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
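
The training cell uses the standard causal language-modelling objective: the logits at position t are scored against the token at position t + 1, which is what the two slices implement. A minimal shape sketch with random tensors (values are hypothetical; shapes match the notebook's context length and the GPT-2 vocabulary):

import torch

B, T, V = 2, 512, 50257
logits = torch.randn(B, T, V)          # stand-in for the model output
labels = torch.randint(0, V, (B, T))   # inputs reused as labels

shift_logits = logits[..., :-1, :]     # (B, T-1, V): drop the final position
shift_labels = labels[..., 1:]         # (B, T-1): drop the first token

loss = torch.nn.functional.cross_entropy(
    shift_logits.reshape(-1, V), shift_labels.reshape(-1))
print(loss.item())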