Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-26 08:41:23 +08:00)
LoRA transpose
File diff suppressed because one or more lines are too long
@@ -42,7 +42,7 @@ class Linear(nn.Module):

    $\Delta W$ is initialized to be zero at the beginning of the training.

-    They multiply $\Delta W x$ by $\frac{\alpha}{r}$ where $\alpha$ is a hyper-parameter.
+    They multiply $x \Delta W^T$ by $\frac{\alpha}{r}$ where $\alpha$ is a hyper-parameter.
     Once $\alpha$ is tuned it can be kept the same when varying $r$.
     """

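In the notation used in the code comments below ($A \in \mathbb{R}^{r \times k}$, $B \in \mathbb{R}^{d \times r}$, $\Delta W = BA$), the row-vector form introduced by this commit follows by transposing the original update; a short derivation, added here for clarity and not part of the commit:

$$x W_0^T + \frac{\alpha}{r} x \Delta W^T = x W_0^T + \frac{\alpha}{r} x (BA)^T = x W_0^T + \frac{\alpha}{r} x A^T B^T$$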
@@ -77,9 +77,9 @@ class Linear(nn.Module):
         # scaling factor $\frac{\alpha}{r}$
         self.scaling = alpha / r
         # Matrix $A \in \mathbb{R}^{r \times k}$
-        self.lora_a = nn.Parameter(torch.empty((in_features, r)))
+        self.lora_a = nn.Parameter(torch.empty((r, in_features)))
         # Matrix $B \in \mathbb{R}^{d \times r}$, we keep $A$ and $B$ transposed
-        self.lora_b = nn.Parameter(torch.empty((r, out_features)))
+        self.lora_b = nn.Parameter(torch.empty((out_features, r)))

         with torch.no_grad():
             # Initialize $A$ similar to a weight matrix in a normal linear layer
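A minimal shape check for the new parameter layout in the Linear layer; this is a sketch with made-up sizes, not code from the repository:

import torch

# Toy sizes, for illustration only
in_features, out_features, r, alpha = 16, 32, 4, 4
scaling = alpha / r

# New layout: A is stored as (r, in_features), B as (out_features, r)
lora_a = torch.randn(r, in_features)
lora_b = torch.randn(out_features, r)

# Folding the update back into a full matrix gives B A with shape
# (out_features, in_features), the same shape as the frozen weight W_0
delta_w = scaling * (lora_b @ lora_a)
assert delta_w.shape == (out_features, in_features)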
@@ -88,11 +88,11 @@ class Linear(nn.Module):
             nn.init.zeros_(self.lora_b)

     def forward(self, x: torch.Tensor):
-        # Compute $W_0 x + b_0$
+        # Compute $x W_0^T + b_0$
         result = nn.functional.linear(x, self.weight, bias=self.bias)

-        # Add $\frac{\alpha}{r} \Delta W x = \frac{\alpha}{r} BAx$
-        result += (x @ self.lora_a @ self.lora_b) * self.scaling
+        # Add $\frac{\alpha}{r} x \Delta W^T = \frac{\alpha}{r} x {(BA)}^T = \frac{\alpha}{r} x A^T B^T$
+        result += (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

         #
         return result
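One way to sanity-check the transposed forward path is to compare it against merging the scaled update into the weight; a sketch with made-up sizes, not code from the repository:

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes, for illustration only
batch, in_features, out_features, r, alpha = 2, 16, 32, 4, 4
scaling = alpha / r

x = torch.randn(batch, in_features)
w0 = torch.randn(out_features, in_features)   # frozen $W_0$
lora_a = torch.randn(r, in_features)          # $A \in \mathbb{R}^{r \times k}$
lora_b = torch.randn(out_features, r)         # $B \in \mathbb{R}^{d \times r}$

# Path used in the layer: $\frac{\alpha}{r} x A^T B^T$
lora_out = (x @ lora_a.T @ lora_b.T) * scaling

# Equivalent: merge the update into the weight, $W_0 + \frac{\alpha}{r} BA$
merged = F.linear(x, w0 + scaling * (lora_b @ lora_a))

assert torch.allclose(F.linear(x, w0) + lora_out, merged, atol=1e-4)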
@@ -123,16 +123,17 @@ class Embedding(nn.Module):
         if alpha is None:
             alpha = r

-        # The pre-trained embedding weights $W_0$ (frozen)
+        nn.Embedding
+        # The pre-trained embedding weights $W_0^T$ (frozen)
         self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
         self.weight.requires_grad = False

         # scaling factor $\frac{\alpha}{r}$
         self.scaling = alpha / r
         # Matrix $A \in \mathbb{R}^{r \times k}$
-        self.lora_a = nn.Parameter(torch.empty((num_embeddings, r)))
+        self.lora_a = nn.Parameter(torch.empty((r, num_embeddings)))
         # Matrix $B \in \mathbb{R}^{d \times r}$
-        self.lora_b = nn.Parameter(torch.empty((r, embedding_dim)))
+        self.lora_b = nn.Parameter(torch.empty((embedding_dim, r)))

         with torch.no_grad():
             # Initialize $A$ with a normal distribution
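The (num_embeddings, embedding_dim) layout of the frozen weight is the same layout `nn.Embedding` uses for its weight, which is presumably why it is annotated as $W_0^T$ after this change; a tiny illustration with toy sizes, not from the commit:

import torch.nn as nn

# nn.Embedding stores its weight as (num_embeddings, embedding_dim)
emb = nn.Embedding(num_embeddings=10, embedding_dim=8)
assert emb.weight.shape == (10, 8)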
@@ -141,11 +142,11 @@ class Embedding(nn.Module):
             nn.init.zeros_(self.lora_b)

     def forward(self, x: torch.Tensor):
-        # Compute the embeddings $W_0 \text{onehot}(x)$
+        # Compute the embeddings $\text{onehot}(x) W_0$
         result = nn.functional.embedding(x, self.weight)

-        # Add $\frac{\alpha}{r} \Delta W \text{onehot}(x) = \frac{\alpha}{r} BA \text{onehot}(x)$
-        result += (nn.functional.embedding(x, self.lora_a) @ self.lora_b) * self.scaling
+        # Add $\frac{\alpha}{r} \text{onehot}(x) \Delta W^T = \frac{\alpha}{r} \text{onehot}(x) A^T B^T$
+        result += (nn.functional.embedding(x, self.lora_a.T) @ self.lora_b.T) * self.scaling

         #
         return result
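As with the Linear layer, the transposed embedding path can be checked against looking up rows of the merged update; a sketch with made-up sizes, not code from the repository:

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes, for illustration only
num_embeddings, embedding_dim, r, alpha = 10, 8, 4, 4
scaling = alpha / r

x = torch.tensor([1, 3, 7])                       # token ids
w0 = torch.randn(num_embeddings, embedding_dim)   # frozen weights
lora_a = torch.randn(r, num_embeddings)
lora_b = torch.randn(embedding_dim, r)

# Path used in the layer: look up rows of $A^T$, then multiply by $B^T$
lora_out = (F.embedding(x, lora_a.T) @ lora_b.T) * scaling

# Equivalent: look up rows of the merged update $\frac{\alpha}{r} (BA)^T$
merged = F.embedding(x, w0 + scaling * (lora_b @ lora_a).T)

assert torch.allclose(F.embedding(x, w0) + lora_out, merged, atol=1e-4)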