mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 08:41:23 +08:00
lora layers
docs/transformers/LoRA/__init__.py (Normal file, 68 additions)
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn


class Linear(nn.Module):
    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool,
            r: int,
            alpha: int = None):
        # Default alpha to the rank, so the scaling alpha / r is 1
        if alpha is None:
            alpha = r
        super().__init__()

        # Frozen pretrained weight (to be loaded from the base model)
        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        self.weight.requires_grad = False

        if bias:
            # Frozen pretrained bias
            self.bias = nn.Parameter(torch.empty(out_features))
            self.bias.requires_grad = False
        else:
            self.bias = None

        # Trainable low-rank factors A (in_features x r) and B (r x out_features),
        # scaled by alpha / r
        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((in_features, r)))
        self.lora_b = nn.Parameter(torch.empty((r, out_features)))

        with torch.no_grad():
            # A gets the same Kaiming-uniform initialization as nn.Linear;
            # B starts at zero, so the LoRA update is zero at initialization
            nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        # Frozen base projection
        result = nn.functional.linear(x, self.weight, bias=self.bias)

        # Add the low-rank update (x A B) * (alpha / r)
        result += (x @ self.lora_a @ self.lora_b) * self.scaling

        return result


class Embedding(nn.Module):
    def __init__(
            self,
            num_embeddings: int,
            embedding_dim: int,
            r: int,
            alpha: int = None,
    ):
        # Default alpha to the rank, so the scaling alpha / r is 1
        if alpha is None:
            alpha = r
        super().__init__()

        # Frozen pretrained embedding table
        self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
        self.weight.requires_grad = False

        # Trainable low-rank factors, scaled by alpha / r
        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((num_embeddings, r)))
        self.lora_b = nn.Parameter(torch.empty((r, embedding_dim)))

        with torch.no_grad():
            # A is drawn from a normal distribution; B starts at zero
            nn.init.normal_(self.lora_a)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        # Frozen base lookup plus the low-rank update, scaled by alpha / r
        result = nn.functional.embedding(x, self.weight)
        result += (nn.functional.embedding(x, self.lora_a) @ self.lora_b) * self.scaling

        return result
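
The layers above keep the pretrained weights frozen and train only the low-rank lora_a / lora_b factors. Below is a minimal usage sketch, not part of the commit: it copies weights from stand-in nn.Linear / nn.Embedding modules into the LoRA layers and runs a forward pass. The sizes (768, 1000), rank r=8, and alpha=16 are arbitrary example values, not anything taken from the repository.

import torch
import torch.nn as nn

# Stand-ins for pretrained layers whose weights we want to adapt (illustrative only).
base_linear = nn.Linear(768, 768, bias=True)
base_embed = nn.Embedding(1000, 768)

# LoRA wrappers with rank r=8 and alpha=16 (illustrative values).
lora_linear = Linear(768, 768, bias=True, r=8, alpha=16)
lora_embed = Embedding(1000, 768, r=8, alpha=16)

with torch.no_grad():
    lora_linear.weight.copy_(base_linear.weight)   # frozen base weight
    lora_linear.bias.copy_(base_linear.bias)       # frozen base bias
    lora_embed.weight.copy_(base_embed.weight)     # frozen embedding table

tokens = torch.randint(0, 1000, (2, 16))
h = lora_embed(tokens)                             # shape (2, 16, 768)
y = lora_linear(h)                                 # base output + (alpha / r) * (h A) B

# Only the adapters are trainable; the base weights stay frozen.
print([n for n, p in lora_linear.named_parameters() if p.requires_grad])
# ['lora_a', 'lora_b']

Because lora_b is initialized to zero, both wrapped layers reproduce the base outputs exactly at the start of fine-tuning; only the A and B factors receive gradients.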
|