LoRA transpose

Varuna Jayasiri
2024-08-18 14:37:14 +05:30
parent ce21dcf76c
commit 9dd97ff11a
2 changed files with 60 additions and 57 deletions

File diff suppressed because one or more lines are too long


@@ -42,7 +42,7 @@ class Linear(nn.Module):
     $\Delta W$ is initialized to be zero at the beginning of the training.
-    They multiply $\Delta W x$ by $\frac{\alpha}{r}$ where $\alpha$ is a hyper-parameter.
+    They multiply $x \Delta W^T$ by $\frac{\alpha}{r}$ where $\alpha$ is a hyper-parameter.
     Once $\alpha$ is tuned it can be kept the same when varying $r$.
     """
@@ -77,9 +77,9 @@ class Linear(nn.Module):
         # scaling factor $\frac{\alpha}{r}$
         self.scaling = alpha / r
         # Matrix $A \in \mathbb{R}^{r \times k}$
-        self.lora_a = nn.Parameter(torch.empty((in_features, r)))
-        # Matrix $B \in \mathbb{R}^{d \times r}$, we keep $A$ and $B$ transposed
-        self.lora_b = nn.Parameter(torch.empty((r, out_features)))
+        self.lora_a = nn.Parameter(torch.empty((r, in_features)))
+        # Matrix $B \in \mathbb{R}^{d \times r}$
+        self.lora_b = nn.Parameter(torch.empty((out_features, r)))

         with torch.no_grad():
             # Initialize $A$ similar to a weight matrix in a normal linear layer
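
For clarity, a standalone sketch (not the repository's code) of the parameter layout after this hunk: $A$ is now stored as `(r, in_features)` and $B$ as `(out_features, r)`, matching $A \in \mathbb{R}^{r \times k}$ and $B \in \mathbb{R}^{d \times r}$ in the paper, and zero-initializing $B$ makes $\Delta W = BA$ zero at the start of training. The `kaiming_uniform_` call below is an assumption; the exact initializer for $A$ is outside this hunk.

import torch
import torch.nn as nn

in_features, out_features, r = 4, 6, 2  # illustrative sizes

# Paper orientation after this commit: A in R^{r x k}, B in R^{d x r}
lora_a = nn.Parameter(torch.empty((r, in_features)))
lora_b = nn.Parameter(torch.empty((out_features, r)))

with torch.no_grad():
    # Assumed Kaiming-style init for A, "similar to a normal linear layer"
    nn.init.kaiming_uniform_(lora_a, a=5 ** 0.5)
    # B starts at zero, so Delta W = BA is zero at the beginning of training
    nn.init.zeros_(lora_b)

delta_w = lora_b @ lora_a  # shape (out_features, in_features), same as W_0
assert torch.equal(delta_w, torch.zeros(out_features, in_features))
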
@@ -88,11 +88,11 @@ class Linear(nn.Module):
             nn.init.zeros_(self.lora_b)

     def forward(self, x: torch.Tensor):
-        # Compute $W_0 x + b_0$
+        # Compute $x W_0^T + b_0$
         result = nn.functional.linear(x, self.weight, bias=self.bias)

-        # Add $\frac{\alpha}{r} \Delta W x = \frac{\alpha}{r} BAx$
-        result += (x @ self.lora_a @ self.lora_b) * self.scaling
+        # Add $\frac{\alpha}{r} x \Delta W^T = \frac{\alpha}{r} x {(BA)}^T = \frac{\alpha}{r} x A^T B^T$
+        result += (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

         #
         return result
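
A quick numerical check (illustrative only, not from the commit) that the new forward pass with transposed storage is equivalent to a single linear layer using the merged weight $W_0 + \frac{\alpha}{r} BA$; tensor names and sizes are made up:

import torch
import torch.nn.functional as F

batch, in_features, out_features, r, alpha = 3, 4, 6, 2, 4
scaling = alpha / r

w0 = torch.randn(out_features, in_features)  # frozen W_0, stored as in nn.Linear
b0 = torch.randn(out_features)               # frozen bias b_0
lora_a = torch.randn(r, in_features)         # A
lora_b = torch.randn(out_features, r)        # B
x = torch.randn(batch, in_features)

# Forward pass as in the new code: x W_0^T + b_0 + (alpha / r) x A^T B^T
result = F.linear(x, w0, bias=b0) + (x @ lora_a.T @ lora_b.T) * scaling

# Same as one linear layer with the merged weight W_0 + (alpha / r) BA
merged = F.linear(x, w0 + scaling * (lora_b @ lora_a), bias=b0)
assert torch.allclose(result, merged, atol=1e-5)
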
@@ -123,16 +123,17 @@ class Embedding(nn.Module):
         if alpha is None:
             alpha = r

-        # The pre-trained embedding weights $W_0$ (frozen)
+        # The pre-trained embedding weights $W_0^T$ (frozen)
         self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
         self.weight.requires_grad = False

         # scaling factor $\frac{\alpha}{r}$
         self.scaling = alpha / r
         # Matrix $A \in \mathbb{R}^{r \times k}$
-        self.lora_a = nn.Parameter(torch.empty((num_embeddings, r)))
+        self.lora_a = nn.Parameter(torch.empty((r, num_embeddings)))
         # Matrix $B \in \mathbb{R}^{d \times r}$
-        self.lora_b = nn.Parameter(torch.empty((r, embedding_dim)))
+        self.lora_b = nn.Parameter(torch.empty((embedding_dim, r)))

         with torch.no_grad():
             # Initialize $A$ with a normal distribution
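
A shape-only sketch (not from the commit) of why the frozen-weight comment now reads $W_0^T$: PyTorch's `nn.Embedding` convention stores the table as `[num_embeddings, embedding_dim]`, which is the transpose of $W_0 \in \mathbb{R}^{d \times k}$ with $d$ the embedding dimension and $k$ the vocabulary size; $A$ and $B$ are stored in the paper's orientation. Sizes are illustrative.

import torch
import torch.nn as nn

num_embeddings, embedding_dim, r = 10, 6, 2  # k = 10, d = 6 in the paper's notation

# Frozen table in the nn.Embedding layout [k, d], i.e. W_0^T
weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)), requires_grad=False)

# After this commit A and B are stored in the paper's orientation
lora_a = nn.Parameter(torch.empty((r, num_embeddings)))  # A in R^{r x k}
lora_b = nn.Parameter(torch.empty((embedding_dim, r)))   # B in R^{d x r}

# Delta W = BA has the same shape as W_0, i.e. [d, k]
assert (lora_b @ lora_a).shape == (embedding_dim, num_embeddings)
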
@@ -141,11 +142,11 @@ class Embedding(nn.Module):
             nn.init.zeros_(self.lora_b)

     def forward(self, x: torch.Tensor):
-        # Compute the embeddings $W_0 \text{onehot}(x)$
+        # Compute the embeddings $\text{onehot}(x) W_0^T$
         result = nn.functional.embedding(x, self.weight)

-        # Add $\frac{\alpha}{r} \Delta W \text{onehot}(x) = \frac{\alpha}{r} BA \text{onehot}(x)$
-        result += (nn.functional.embedding(x, self.lora_a) @ self.lora_b) * self.scaling
+        # Add $\frac{\alpha}{r} \text{onehot}(x) \Delta W^T = \frac{\alpha}{r} \text{onehot}(x) A^T B^T$
+        result += (nn.functional.embedding(x, self.lora_a.T) @ self.lora_b.T) * self.scaling

         #
         return result
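
Finally, a quick check (illustrative, not from the commit) that the new embedding forward pass equals a lookup into the merged table $W_0^T + \frac{\alpha}{r} (BA)^T$: `embedding(x, lora_a.T) @ lora_b.T` picks the rows of $A^T$ for the given ids and multiplies by $B^T$, which matches picking rows of $(BA)^T$. Names and sizes are made up.

import torch
import torch.nn.functional as F

num_embeddings, embedding_dim, r, alpha = 10, 6, 2, 4
scaling = alpha / r

weight = torch.randn(num_embeddings, embedding_dim)  # W_0^T, as nn.Embedding stores it
lora_a = torch.randn(r, num_embeddings)              # A
lora_b = torch.randn(embedding_dim, r)               # B
x = torch.randint(0, num_embeddings, (3, 5))         # a batch of token ids

# Forward pass as in the new code
result = F.embedding(x, weight) + (F.embedding(x, lora_a.T) @ lora_b.T) * scaling

# Same as embedding with the merged table W_0^T + (alpha / r) (BA)^T
merged = F.embedding(x, weight + scaling * (lora_b @ lora_a).T)
assert torch.allclose(result, merged, atol=1e-5)
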