From f6e913eb09cabca03d7c015867ec4929de8c3d1b Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Thu, 27 Jun 2024 19:35:37 +0530
Subject: [PATCH 01/16] transformer mha chinese translation
---
docs/sitemap.xml | 2 +-
docs/zh/index.html | 4 +-
docs/zh/sitemap.xml | 2 +-
docs/zh/transformers/configs.html | 84 ++++++------
docs/zh/transformers/feed_forward.html | 57 ++++----
docs/zh/transformers/index.html | 94 ++++++-------
.../zh/transformers/label_smoothing_loss.html | 10 +-
docs/zh/transformers/mha.html | 128 +++++++++---------
docs/zh/transformers/models.html | 74 +++++-----
docs/zh/transformers/positional_encoding.html | 10 +-
docs/zh/transformers/relative_mha.html | 6 +-
docs/zh/transformers/utils.html | 20 +--
.../transformers/feed_forward.zh.json | 2 +-
13 files changed, 245 insertions(+), 248 deletions(-)
diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index e1d8e169..7b46859e 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -1450,7 +1450,7 @@
https://nn.labml.ai/rl/ppo/gae.html
- 2023-10-24T16:30:00+00:00
+ 2024-06-24T16:30:00+00:00
1.00
diff --git a/docs/zh/index.html b/docs/zh/index.html
index 09208a7d..322f4bf7 100644
--- a/docs/zh/index.html
+++ b/docs/zh/index.html
@@ -72,7 +72,7 @@
这是一个用 PyTorch 实现各种神经网络和相关算法的集合。每个算法的代码实现 都有详细的解释说明,且在网站 上与代码逐行对应。我们相信,这些内容将帮助您更好地理解这些算法。
-我们正在积极维护这个仓库并添加新的代码实现 以获取更新。
+我们正在积极维护这个仓库并添加新的代码实现。 以获取更新。
翻译
@@ -102,7 +102,7 @@
Primer
沙漏网络
-在一块 48GB GPU 上生成
+
diff --git a/docs/zh/sitemap.xml b/docs/zh/sitemap.xml
index e1d8e169..7b46859e 100644
--- a/docs/zh/sitemap.xml
+++ b/docs/zh/sitemap.xml
@@ -1450,7 +1450,7 @@
https://nn.labml.ai/rl/ppo/gae.html
- 2023-10-24T16:30:00+00:00
+ 2024-06-24T16:30:00+00:00
1.00
diff --git a/docs/zh/transformers/configs.html b/docs/zh/transformers/configs.html
index 91013706..d7440d3c 100644
--- a/docs/zh/transformers/configs.html
+++ b/docs/zh/transformers/configs.html
@@ -7,20 +7,20 @@
-
+
-
+
-
+
-
+
- 可配置变压器组件
+ 可配置 Transformer 组件
@@ -70,7 +70,7 @@
- 可配置变压器组件
+ 可配置的 Transformer 组件
@@ -93,8 +93,8 @@
FFN 配置
-创建在中定义的位置前馈网络feed_forward . py
- 。
+在feed_forward . py
+ 中定义了一个位置前馈网络。
@@ -118,7 +118,7 @@
-
嵌入中的要素数量
+
嵌入的特征数量
@@ -130,7 +130,7 @@
-
隐藏图层中的要素数量
+
隐藏层中的特征数量
@@ -142,7 +142,7 @@
-
辍学概率
+
Dropout 率
@@ -154,7 +154,7 @@
-
在位置前馈层激活
+
位置前馈层中的激活函数
@@ -178,7 +178,7 @@
-
第一个完全连接的层是否应该有可学习的偏差
+
第一个全连接层是否具有可学习的偏置
@@ -190,7 +190,7 @@
-
第二个全连接层是否应该有可学习的偏差
+
第二个全连接层是否具有可学习的偏置
@@ -202,7 +202,7 @@
-
栅极的全连接层是否应具有可学习的偏差
+
门控的全连接层是否具有可学习的偏置
@@ -226,7 +226,7 @@
-
激活 ReLU
+
ReLU 激活函数
max ( 0 , x )
@@ -251,9 +251,9 @@
- GELU 激活
-x Φ ( x ) 在哪里Φ ( x ) = P ( X ≤ x ) , X ∼ N ( 0 , 1 )
-它是在论文中介绍的 “高斯误差线性单位 ”。
+ GELU 激活函数
+x Φ ( x ) 其中,Φ ( x ) = P ( X ≤ x ) , X ∼ N ( 0 , 1 )
+这是在论文《 Gaussian Error Linear Units 》 中介绍的。
GLU 变体
-这些是用于FFN的封闭隐藏层的变体,如纸质 GLU变体改进变压器 中所述。我们省略了本文中指定的偏差术语。
+这些是论文《 GLU Variants Improve Transformer 》中介绍的用于 FFN 的带门控隐藏层的变体。我们已按照论文规定省略了偏置项。
@@ -356,7 +356,7 @@
-
带有 ReLU 门的 FFN
+
带 ReLU 门的 FFN
FF N R e G LU ( x ) ( x , W 1 , V , W 2 ) = ( max ( 0 , x W 1 ) ⊗ x V ) W 2
@@ -374,7 +374,7 @@
- 带有 GELU 门的 FFN
+ 带 GELU 门的 FFN
FF N GEG LU ( x ) ( x , W 1 , V , W 2 ) = ( GELU ( x W 1 ) ⊗ x V ) W 2
@@ -392,8 +392,8 @@
- FFN 带 Swish gate
-FF N Sw i G LU ( x ) ( x , W 1 , V , W 2 ) = ( Swish 1 ( x W 1 ) ⊗ x V ) W 2 在哪里Swish β ( x ) = x σ ( β x )
+ 带 Swish 门的 FFN
+FF N Sw i G LU ( x ) ( x , W 1 , V , W 2 ) = ( Swish 1 ( x W 1 ) ⊗ x V ) W 2 其中,Swish β ( x ) = x σ ( β x )
-变压器配置
-这定义了变压器的配置。配置是使用选项函数计算的。这些是延迟加载的,因此只计算必要的模块。
+Transformer 配置
+这定义了 Transformer 的配置。这些配置是通过选项函数计算的。它们是惰性加载的，因此只有必要的模块才会被计算。
@@ -424,7 +424,7 @@
-
注意头数量
+
注意力头数量
@@ -436,7 +436,7 @@
-
变压器嵌入尺寸
+
Transformer 嵌入大小
@@ -460,7 +460,7 @@
-
辍学概率
+
Dropout 率
@@ -472,7 +472,7 @@
-
源词汇表中的标记数(用于令牌嵌入)
+
源词汇表中的 token 数量(用于 token 嵌入)
@@ -484,7 +484,7 @@
-
目标词汇表中的标记数(用于生成预测的对数)
+
目标词汇表中的 token 数量(用于生成预测的 logits )
@@ -496,7 +496,7 @@
-
编码器自我注意
+
编码器自注意力
@@ -508,7 +508,7 @@
-
解码器自我注意
+
解码器自注意力
@@ -520,7 +520,7 @@
-
解码器内存注意事项
+
解码器对记忆的注意力
@@ -592,7 +592,7 @@
-
源的嵌入层
+
源数据的嵌入层
@@ -604,7 +604,7 @@
-
目标嵌入层(用于解码器)
+
目标数据的嵌入层(用于解码器)
@@ -640,7 +640,7 @@
-
多头注意
+
多头注意力
@@ -877,8 +877,8 @@
-
学习过的位置嵌入
-
使用学习的位置编码进行源嵌入
+
可学习的位置嵌入
+
使用可学习的位置编码进行源嵌入
@@ -902,7 +902,7 @@
-
使用学习的位置编码进行目标嵌入
+
使用可学习的位置编码进行目标嵌入
@@ -926,8 +926,8 @@
-
没有位置嵌入
-
不带位置编码的源代码嵌入
+
无位置嵌入
+
没有位置编码的源嵌入
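The three gated variants listed above differ only in the gate activation. A minimal sketch, assuming plain nn.Linear projections for W1, V and W2 with the biases omitted as in the GLU-variants paper (module and variable names are illustrative, not the repo's actual classes):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedFFN(nn.Module):
    # FFN_variant(x, W1, V, W2) = (g(x W1) ⊗ x V) W2, with the gate activation g as a parameter
    def __init__(self, d_model: int, d_ff: int, activation):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.v = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.activation(self.w1(x)) * self.v(x))

re_glu = GatedFFN(512, 2048, F.relu)    # FFN_ReGLU
ge_glu = GatedFFN(512, 2048, F.gelu)    # FFN_GEGLU
swi_glu = GatedFFN(512, 2048, F.silu)   # FFN_SwiGLU, Swish_1(x) = x * sigmoid(x)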
diff --git a/docs/zh/transformers/feed_forward.html b/docs/zh/transformers/feed_forward.html
index f663c74d..f2b86073 100644
--- a/docs/zh/transformers/feed_forward.html
+++ b/docs/zh/transformers/feed_forward.html
@@ -3,12 +3,12 @@
-
+
-
+
@@ -18,7 +18,7 @@
-
+
位置前馈网络 (FFN)
@@ -70,17 +70,16 @@
-
位置前馈网络 (FFN)
-
这是变压器中使用的按位置前馈网络的 PyTorch 实现。
-
FFN 由两个完全连接的层组成。隐藏层中的维度数d f f ,通常设置为令牌嵌入的四倍左右d m o d e l 。因此,它有时也被称为扩张和收缩网络。
-
隐藏层有一个激活,通常设置为RelU(整流线性单元)激活,max ( 0 , x )
-
也就是说,FFN 函数是、FFN ( x , W 1 , W 2 , b 1 , b 2 ) = max ( 0 , x W 1 + b 1 ) W 2 + b 2 其中W 1 W 2 、b 1 和b 2 是可学习的参数。
-
有时还会使用 GELU(高斯误差线性单位)激活来代替 RelU。x Φ ( x ) 在哪里Φ ( x ) = P ( X ≤ x ) , X ∼ N ( 0 , 1 )
+
位置前馈网络 (FFN)
+
这是 Transformer 中使用的位置前馈网络的 PyTorch 实现。
+
FFN 由两个全连接层组成。隐藏层中的维度数_%5e_0_%5e_ 通常设置为标记嵌入维度_%5e_1_%5e_ 的四倍左右。因此,它有时也被称为扩张-压缩网络。
+
隐藏层有一个激活函数,通常设置为 ReLU (Rectified Linear Unit) 激活函数,_%5e_2_%5e_
+
在此基础上, FFN 函数可以写作:_%5e_3_%5e_ 其中_%5e_4_%5e_ _%5e_5_%5e_ 、_%5e_6_%5e_ 和_%5e_7_%5e_ 是可学习的参数。
+
有时还会使用 GELU (Gaussian Error Linear Unit) 激活函数来代替 ReLU 。_%5e_8_%5e_ 其中_%5e_9_%5e_
门控线性单元
-
这是一个通用实现,支持不同的变体,包括门控线性单元 (GLU)。我们还对以下方面进行了实验:
-
+
这是一个通用实现,支持包括门控线性单元(GLU) 在内的不同变体。我们还对这些进行了实验:
+
d_model
-是令牌嵌入中的要素数量
+是标记嵌入中的特征数量
d_ff
-是 FFN 隐藏层中的要素数量
+是 FFN 隐藏层中的特征数量
dropout
-是隐藏层的丢失概率
+是隐藏层的 Dropout 率
is_gated
-指定隐藏层是否为门控
+指定了隐藏层是否为门控层
bias1
-指定第一个完全连接的层是否应该有可学习的偏差
+指定了第一个全连接层是否应该具有可学习的偏置
bias2
-指定第二个完全连接的层是否应该有可学习的偏差
+指定第二个全连接层是否应具有可学习的偏置
bias_gate
-指定门的全连接层是否应具有可学习的偏差
+指定门控的全连接层是否应具有可学习的偏置
@@ -149,7 +148,7 @@
-
第一层按权重W 1 和偏差进行参数化b 1
+
第一层由权重W 1 和偏差b 1 进行参数化
@@ -161,7 +160,7 @@
-
第一层按权重W 1 和偏差进行参数化b 1
+
第一层由权重W 1 和偏差b 1 进行参数化
@@ -173,7 +172,7 @@
-
隐藏图层丢失
+
隐藏层 Dropout
@@ -185,7 +184,7 @@
-
激活功能f
+
激活函数f
@@ -197,7 +196,7 @@
-
是否有门
+
是否存在门控
@@ -210,7 +209,7 @@
-
如果有门,则转换输入的线性层将乘以门,并通过权重V 和偏置进行参数化c
+
+如果存在门控，则还有一个线性层对输入进行变换，其输出与门相乘；该层由权重 V 和偏置c 进行参数化
@@ -245,7 +244,7 @@
-
如果是封闭的,f ( x W 1 + b 1 ) ⊗ ( x V + b )
+
如果进行门控,f ( x W 1 + b 1 ) ⊗ ( x V + b )
@@ -271,7 +270,7 @@
-
申请退学
+
使用 Dropout
@@ -283,7 +282,7 @@
-
( f ( x W 1 + b 1 ) ⊗ ( x V + b )) W 2 + b 2 或者f ( x W 1 + b 1 ) W 2 + b 2 取决于它是否有门控
+
根据是否进行门控,返回( f ( x W 1 + b 1 ) ⊗ ( x V + b )) W 2 + b 2 或者f ( x W 1 + b 1 ) W 2 + b 2
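A minimal sketch of the ungated position-wise FFN described above, assuming d_ff is four times d_model and a ReLU activation (names are illustrative; the repo's actual module also supports the gated variants):

import torch
import torch.nn as nn

class PositionWiseFFN(nn.Module):
    # FFN(x, W1, W2, b1, b2) = max(0, x W1 + b1) W2 + b2, applied independently at every position
    def __init__(self, d_model: int = 512, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)   # W1, b1
        self.layer2 = nn.Linear(d_ff, d_model)   # W2, b2
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()              # GELU is a common alternative

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer2(self.dropout(self.activation(self.layer1(x))))

x = torch.randn(10, 2, 512)                      # [seq_len, batch_size, d_model]
print(PositionWiseFFN()(x).shape)                # torch.Size([10, 2, 512])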
diff --git a/docs/zh/transformers/index.html b/docs/zh/transformers/index.html
index 43f34b9e..fb9536ab 100644
--- a/docs/zh/transformers/index.html
+++ b/docs/zh/transformers/index.html
@@ -3,24 +3,24 @@
-
+
-
-
+
+
-
+
-
+
-
-
+
+
-
变压器
+
Transformers
@@ -70,50 +70,50 @@
-
变压器
-
本模块包含 PyTorch 实现和论文 Attention Is All You Need 中对原创变压器的解释,以及它的衍生品和增强功能。
-
diff --git a/docs/zh/transformers/label_smoothing_loss.html b/docs/zh/transformers/label_smoothing_loss.html
index cbafc90c..3f2b49b0 100644
--- a/docs/zh/transformers/label_smoothing_loss.html
+++ b/docs/zh/transformers/label_smoothing_loss.html
@@ -3,12 +3,12 @@
-
+
-
+
@@ -18,7 +18,7 @@
-
+
标签平滑损失
@@ -154,7 +154,7 @@
-
显示系统预期的目标分布。
+
展示系统期望的目标分布。
@@ -183,7 +183,7 @@
-
打印(预测)
+
输出(预测)
diff --git a/docs/zh/transformers/mha.html b/docs/zh/transformers/mha.html
index 4fce0daf..71798236 100644
--- a/docs/zh/transformers/mha.html
+++ b/docs/zh/transformers/mha.html
@@ -3,24 +3,24 @@
-
+
-
-
+
+
-
+
-
+
-
-
+
+
-
多头注意 (MHA)
+ 多头注意力 (MHA)
@@ -72,9 +72,7 @@
多头注意力 (MHA)
-这是 P yTorch 中论文 “注意力 就是你所需要的” 多头注意 力的教程/实现。该实现的灵感来自带注释的变形金刚 。
-以下是使用带有 MHA 的基本转换器进行 NLP 自动回归的训练代码 。
-这是一个训练简单变压器的实验实现 。
+这是论文《 Attention is All You Need 》中多头注意力的 PyTorch 教程/实现。该实现的灵感来自《 The Annotated Transformer 》。
这是使用基础 Transformer 和 MHA 进行 NLP 自回归的训练代码 。
这是一个训练简单 Transformer 的代码实现 。
-为多头注意做好准备
-该模块进行线性变换,并将向量拆分为给定数量的头部,以获得多头注意。这用于转换键 、查询 和值 向量。
+准备多头注意力
+该部分执行线性变换,并将向量分割成给定数量的头以获得多头注意力。这用于键 、查询 和值 向量。
@@ -118,7 +116,7 @@
-
线性变换的线性层
+
用于线性变换的线性层
@@ -130,7 +128,7 @@
-
头数
+
注意力头数
@@ -142,7 +140,7 @@
-
每个头部中以向量表示的维度数
+
每个头部中向量的维度数量
@@ -165,9 +163,9 @@
-
输入的形状[ seq_len , batch_size , d_model ]
+
输入的形状为[ seq_len , batch_size , d_model ]
或[ batch_size , d_model ]
-。我们将线性变换应用于最后一个维度,然后将其拆分为头部。
+。我们对最后一维应用线性变换,并将其分为多个头。
@@ -191,7 +189,7 @@
-
将最后一个维度拆分成头部
+
将最后一个维度分成多个头部
-多头注意模块
-这将计算给定key
-和value
-向量的缩放多头注意query
-力。
+多头注意力模块
+这将计算给出的key
+、value
+和query
+向量缩放后的多头注意力。
Attention ( Q , K , V ) = softmax ( Q K ⊤ / √ d k ) V
-简单来说,它会找到与查询匹配的键,并获取这些键的值。
-它使用查询和键的点积作为它们匹配程度的指标。在服用点产品之前so f t ma x ,先按比例缩放d k so f t ma x 之前,点积会被d k 1 。这样做是为了避免较大的点积值导致 softmax 在较大时d k 给出非常小的梯度。
-Softmax 是沿序列(或时间)的轴计算的。
+在计算 softmax 之前，点积会先乘以 1 / √ d k 。这样做是为了避免当d k 较大时，大的点积值导致 Softmax 操作输出非常小的梯度。
+Softmax 是沿序列(或时间)轴计算的。
@@ -261,12 +259,12 @@ M834 80h400000v40h-400000z">heads
-是头的数量。
+是注意力头的数量。
d_model
-是query
+是向量query
、key
和value
-向量中的要素数。
+中的特征数量。
@@ -289,7 +287,7 @@ M834 80h400000v40h-400000z">
query
+ 这些将对多头注意力的向量query
、key
和value
-向量。
+进行转换。
@@ -330,7 +328,7 @@ M834 80h400000v40h-400000z">
key
+ 在键( Key )的时间维度上进行注意力 Softmaxkey
@@ -355,7 +353,7 @@ M834 80h400000v40h-400000z">mask
-有形状[ seq_len_q , seq_len_k , batch_size ]
-,其中第一个维度是查询维度。如果查询维度等于1 它将被广播。
+的形状为[ seq_len_q , seq_len_k , batch_size ]
+,其中第一维是查询维度。如果查询维度等于1 ,则会进行广播。
@@ -443,7 +441,7 @@ M834 80h400000v40h-400000z">
[ seq_len_q , seq_len_k , batch_size , heads ]
+ 生成的掩码形状为[ seq_len_q , seq_len_k , batch_size , heads ]
@@ -471,15 +469,15 @@ M834 80h400000v40h-400000z">query
key
和value
-是存储查询 、键 和值 向量集合的张量。它们有形状[ seq_len , batch_size , d_model ]
+是存储查询 、键 和值 向量集合的张量。它们的形状为[ seq_len , batch_size , d_model ]
。
mask
-有形状[ seq_len , seq_len , batch_size ]
-并mask [ i , j , b ]
-指示是否为批量查询b
-,位置处的查询i
-有权访问位置处的键值j
-。
+的形状为[ seq_len , seq_len , batch_size ]
+,mask [ i , j , b ]
+表示批次b
+,在位置i
+处查询是否有权访问位置j
+处的键值对。
@@ -497,8 +495,8 @@ M834 80h400000v40h-400000z">query
,key
-并且value
-有形状[ seq_len , batch_size , d_model ]
+和value
+的形状为[ seq_len , batch_size , d_model ]
@@ -514,10 +512,10 @@ M834 80h400000v40h-400000z">query
+ 为注意力计算准备向量query
,key
并value
-进行注意力计算。然后这些就会有形状[ seq_len , batch_size , heads , d_k ]
+它们的形状将变为[ seq_len , batch_size , heads , d_k ]
。
@@ -532,8 +530,8 @@ M834 80h400000v40h-400000z">Q K ⊤ 。这给出了形状的张量[ seq_len , seq_len , batch_size , heads ]
-。
+ 计算注意力分数Q K ⊤ 这将得到一个形状为[ seq_len , seq_len , batch_size , heads ]
+的张量。
@@ -545,7 +543,7 @@ M834 80h400000v40h-400000z">
 沿键序列维度计算 softmax ( Q K ⊤ / √ d k )
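A condensed sketch of the computation this file documents, not the repo's actual module: inputs are assumed to be [seq_len, batch_size, d_model], scores are Q K ⊤ / √ d k, and the softmax is taken along the key dimension (the real implementation additionally factors out a prepare-heads helper and applies dropout to the attention weights):

import math
import torch
import torch.nn as nn

class SimpleMultiHeadAttention(nn.Module):
    def __init__(self, heads: int, d_model: int):
        super().__init__()
        self.heads = heads
        self.d_k = d_model // heads
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.output = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        seq_len, batch_size, _ = q.shape
        # Split the last dimension into (heads, d_k)
        q = self.query(q).view(seq_len, batch_size, self.heads, self.d_k)
        k = self.key(k).view(-1, batch_size, self.heads, self.d_k)
        v = self.value(v).view(-1, batch_size, self.heads, self.d_k)
        # Scores have shape [seq_len_q, seq_len_k, batch_size, heads]
        scores = torch.einsum('ibhd,jbhd->ijbh', q, k) / math.sqrt(self.d_k)
        if mask is not None:              # mask: [seq_len_q, seq_len_k, batch_size]
            scores = scores.masked_fill(mask.unsqueeze(-1) == 0, float('-inf'))
        attn = scores.softmax(dim=1)      # softmax along the key (sequence) dimension
        out = torch.einsum('ijbh,jbhd->ibhd', attn, v)
        return self.output(out.reshape(seq_len, batch_size, -1))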
-
+
-
-
+
+
-
+
-
+
-
-
+
+
- 变压器编码器和解码器型号
+ Transformer 编码器和解码器模型
@@ -70,7 +70,7 @@
- 变压器编码器和解码器模型
+ Transformer 编码器和解码器模型
@@ -92,7 +92,7 @@
#
-
+
-嵌入令牌并添加参数化的位置编码
+嵌入 token 并添加参数化的位置编码
@@ -175,7 +175,7 @@
Transformer Layer
-This can act as an encoder layer or a decoder layer. We use pre-norm.
+这可以作为编码器层或解码器层。我们使用预归一化（pre-norm）。
d_model
-是令牌嵌入的大小
+是 token 嵌入大小
self_attn
-是自我关注模块
+是自注意力模块
src_attn
-是源关注模块(当它在解码器中使用时)
+是源注意力模块（当它用于解码器时）
feed_forward
是前馈模块
dropout_prob
-是自我关注和 FFN 后退学的概率
+是自注意力和 FFN 后的 Dropout 率
@@ -272,7 +272,7 @@
-
通过自我关注,即关键和价值来自自我
+
通过自注意力机制运行,即键和值来自于自身
@@ -284,7 +284,7 @@
-
添加自我关注的结果
+
添加自注意力结果
@@ -296,7 +296,7 @@
-
如果提供了来源,则从关注源获取结果。这是当你有一个关注编码器输出的解码器层
时
+
+如果提供了源数据，则从对源的注意力中获取结果。这是指当解码器层关注编码器输出时。
@@ -320,7 +320,7 @@
-
注意源。即键和值来自源
+
关注源数据,即键和值来自源数据
@@ -332,7 +332,7 @@
-
添加来源关注结果
+
添加源关注结果
@@ -356,7 +356,7 @@
-
如果已指定,则将输入保存到前馈图层
+
如果已指定,则将输入保存到前馈层
@@ -369,7 +369,7 @@
-
通过前馈网络
+
通过前馈网络传递
-变压器编码
+Transformer 编码器
@@ -420,7 +420,7 @@
-
制作变压器层的副本
+
制作 Transformer 层的副本
@@ -432,7 +432,7 @@
-
最终归一化层
+
最终的归一化层
@@ -455,7 +455,7 @@
-
穿过每个变压器层
+
运行每个 Transformer 层
-变压器解码器
+Transformer 解码器
@@ -505,7 +505,7 @@
-
制作变压器层的副本
+
制作 Transformer 层的副本
@@ -517,7 +517,7 @@
-
最终归一化层
+
最终的归一化层
@@ -540,7 +540,7 @@
-
穿过每个变压器层
+
运行每个 Transformer 层
-发电机
-这可以预测令牌并给出其中的lof softmax。如果你正在使用,你不需要这个nn . CrossEntropyLoss
-。
+生成器
+这会预测这些标记并给出它们的 softmax 的对数。如果你使用nn . CrossEntropyLoss
+,则不需要这样做。
-组合式编码器-解码器
+组合编码器-解码器
@@ -635,7 +635,7 @@
-
从他们的代码来看,这很重要。使用 Glorot/fan_avg 初始化参数。
+
这是代码中很重要的部分。使用 Glorot/fan_avg 初始化参数。
@@ -660,7 +660,7 @@
-
通过编码器运行源码
+
+通过编码器运行源序列
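A sketch of the pre-norm layer the text above describes, assuming attention modules that take (query, key, value, mask); src_attn is only used when the layer sits in a decoder and attends to the encoder output:

import torch
import torch.nn as nn

class PreNormTransformerLayer(nn.Module):
    def __init__(self, d_model, self_attn, src_attn, feed_forward, dropout_prob):
        super().__init__()
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm(d_model)
        self.norm_src_attn = nn.LayerNorm(d_model)
        self.norm_ff = nn.LayerNorm(d_model)

    def forward(self, x, mask=None, src=None, src_mask=None):
        z = self.norm_self_attn(x)
        x = x + self.dropout(self.self_attn(z, z, z, mask))      # self-attention: keys and values come from x itself
        if src is not None:                                      # decoder layer attending to the encoder output
            z = self.norm_src_attn(x)
            x = x + self.dropout(self.src_attn(z, src, src, src_mask))
        z = self.norm_ff(x)
        x = x + self.dropout(self.feed_forward(z))               # position-wise feed-forward network
        return x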
diff --git a/docs/zh/transformers/positional_encoding.html b/docs/zh/transformers/positional_encoding.html
index 4d14d992..c36fdbb8 100644
--- a/docs/zh/transformers/positional_encoding.html
+++ b/docs/zh/transformers/positional_encoding.html
@@ -3,12 +3,12 @@
-
+
-
+
@@ -18,7 +18,7 @@
-
+
固定位置编码
@@ -153,7 +153,7 @@
-
头寸指数
+
位置索引
@@ -213,7 +213,7 @@
-
添加批量维度
+
增加批处理维度
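A small sketch of the fixed sinusoidal encodings this file covers; the position index and the added batch dimension correspond to the two comments above (the function name and default sizes are illustrative):

import math
import torch

def sinusoidal_encodings(d_model: int, max_len: int) -> torch.Tensor:
    # PE[pos, 2i] = sin(pos / 10000^(2i / d_model)), PE[pos, 2i + 1] = cos(pos / 10000^(2i / d_model))
    position = torch.arange(max_len).unsqueeze(1).float()                 # position indices
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(1)                                                # add a batch dimension: [max_len, 1, d_model]

print(sinusoidal_encodings(512, 5000).shape)    # torch.Size([5000, 1, 512])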
diff --git a/docs/zh/transformers/relative_mha.html b/docs/zh/transformers/relative_mha.html
index 3c08c520..143d6371 100644
--- a/docs/zh/transformers/relative_mha.html
+++ b/docs/zh/transformers/relative_mha.html
@@ -3,13 +3,13 @@
-
+
-
+
@@ -19,7 +19,7 @@
-
+
相对多头注意力
diff --git a/docs/zh/transformers/utils.html b/docs/zh/transformers/utils.html
index f4cdcbf4..28082e4c 100644
--- a/docs/zh/transformers/utils.html
+++ b/docs/zh/transformers/utils.html
@@ -3,24 +3,24 @@
-
+
-
-
+
+
-
+
-
+
-
-
+
+
-
变压器公用事业
+
Transformer 实用工具
@@ -70,7 +70,7 @@
-
变压器公用事业
+
Transformer 实用工具
@@ -82,7 +82,7 @@
-
后续掩码,用于掩盖未来(后续)时间步中的数据
+
用于屏蔽未来(后续)时间步数据的后续掩码
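A possible sketch of that subsequent mask, assuming a boolean mask of shape [seq_len, seq_len, 1] that broadcasts over the batch dimension:

import torch

def subsequent_mask(seq_len: int) -> torch.Tensor:
    # Lower-triangular mask: position i may attend to positions j <= i only.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return mask.unsqueeze(-1)

print(subsequent_mask(4)[:, :, 0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])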
diff --git a/translate_cache/transformers/feed_forward.zh.json b/translate_cache/transformers/feed_forward.zh.json
index 038915f0..719c685d 100644
--- a/translate_cache/transformers/feed_forward.zh.json
+++ b/translate_cache/transformers/feed_forward.zh.json
@@ -1,5 +1,5 @@
{
- "
Position-wise Feed-Forward Network (FFN) \n
This is a PyTorch implementation of position-wise feedforward network used in transformer.
\n
FFN consists of two fully connected layers. Number of dimensions in the hidden layer _^_0_^_ , is generally set to around four times that of the token embedding _^_1_^_ . So it is sometime also called the expand-and-contract network.
\n
There is an activation at the hidden layer, which is usually set to ReLU (Rectified Linear Unit) activation, _^_2_^_
\n
That is, the FFN function is, _^_3_^_ where _^_4_^_ , _^_5_^_ , _^_6_^_ and _^_7_^_ are learnable parameters.
\n
Sometimes the GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU. _^_8_^_ where _^_9_^_
\n
Gated Linear Units \n
This is a generic implementation that supports different variants including Gated Linear Units (GLU). We have also implemented experiments on these:
\n
\n": "
\u4f4d\u7f6e\u524d\u9988\u7f51\u7edc (FFN) \n
\u8fd9\u662f Transformer \u4e2d\u4f7f\u7528\u7684\u4f4d\u7f6e\u524d\u9988\u7f51\u7edc\u7684 PyTorch \u5b9e\u73b0\u3002
\n
FFN \u7531\u4e24\u4e2a\u5168\u8fde\u63a5\u5c42\u7ec4\u6210\u3002\u9690\u85cf\u5c42\u4e2d\u7684\u7ef4\u5ea6\u6570_%5e_0_%5e_ \u901a\u5e38\u8bbe\u7f6e\u4e3a\u6807\u8bb0\u5d4c\u5165\u7ef4\u5ea6_%5e_1_%5e_ \u7684\u56db\u500d\u5de6\u53f3\u3002\u56e0\u6b64\uff0c\u5b83\u6709\u65f6\u4e5f\u88ab\u79f0\u4e3a\u6269\u5f20-\u538b\u7f29\u7f51\u7edc\u3002
\n
\u9690\u85cf\u5c42\u6709\u4e00\u4e2a\u6fc0\u6d3b\u51fd\u6570\uff0c\u901a\u5e38\u8bbe\u7f6e\u4e3a ReLU (Rectified Linear Unit) \u6fc0\u6d3b\u51fd\u6570\uff0c_%5e_2_%5e_
\n
\u5728\u6b64\u57fa\u7840\u4e0a\uff0c FFN \u51fd\u6570\u53ef\u4ee5\u5199\u4f5c\uff1a_%5e_3_%5e_ \u5176\u4e2d_%5e_4_%5e_ _%5e_5_%5e_ \u3001_%5e_6_%5e_ \u548c_%5e_7_%5e_ \u662f\u53ef\u5b66\u4e60\u7684\u53c2\u6570\u3002
\n
\u6709\u65f6\u8fd8\u4f1a\u4f7f\u7528 GELU (Gaussian Error Linear Unit) \u6fc0\u6d3b\u51fd\u6570\u6765\u4ee3\u66ff ReLU \u3002_%5e_8_%5e_ \u5176\u4e2d_%5e_9_%5e_
\n
\u95e8\u63a7\u7ebf\u6027\u5355\u5143 \n
\u8fd9\u662f\u4e00\u4e2a\u901a\u7528\u5b9e\u73b0\uff0c\u652f\u6301\u5305\u62ec\u95e8\u63a7\u7ebf\u6027\u5355\u5143(GLU) \u5728\u5185\u7684\u4e0d\u540c\u53d8\u4f53\u3002\u6211\u4eec\u8fd8\u5bf9\u8fd9\u4e9b\u8fdb\u884c\u4e86\u5b9e\u9a8c\uff1a
\n
\n",
+ "
Position-wise Feed-Forward Network (FFN) \n
This is a PyTorch implementation of position-wise feedforward network used in transformer.
\n
FFN consists of two fully connected layers. Number of dimensions in the hidden layer _^_0_^_ , is generally set to around four times that of the token embedding _^_1_^_ . So it is sometime also called the expand-and-contract network.
\n
There is an activation at the hidden layer, which is usually set to ReLU (Rectified Linear Unit) activation, _^_2_^_
\n
That is, the FFN function is, _^_3_^_ where _^_4_^_ , _^_5_^_ , _^_6_^_ and _^_7_^_ are learnable parameters.
\n
Sometimes the GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU. _^_8_^_ where _^_9_^_
\n
Gated Linear Units \n
This is a generic implementation that supports different variants including Gated Linear Units (GLU). We have also implemented experiments on these:
\n
\n": "
\u4f4d\u7f6e\u524d\u9988\u7f51\u7edc \uff08FFN\uff09 \n
\u8fd9\u662f Transformer \u4e2d\u4f7f\u7528\u7684\u4f4d\u7f6e\u524d\u9988\u7f51\u7edc\u7684 PyTorch \u5b9e\u73b0\u3002
\n
FFN \u7531\u4e24\u4e2a\u5168\u8fde\u63a5\u5c42\u7ec4\u6210\u3002\u9690\u85cf\u5c42\u4e2d\u7684\u7ef4\u5ea6\u6570_%5e_0_%5e_ \u901a\u5e38\u8bbe\u7f6e\u4e3a\u6807\u8bb0\u5d4c\u5165\u7ef4\u5ea6_%5e_1_%5e_ \u7684\u56db\u500d\u5de6\u53f3\u3002\u56e0\u6b64\uff0c\u5b83\u6709\u65f6\u4e5f\u88ab\u79f0\u4e3a\u6269\u5f20-\u538b\u7f29\u7f51\u7edc\u3002
\n
\u9690\u85cf\u5c42\u6709\u4e00\u4e2a\u6fc0\u6d3b\u51fd\u6570\uff0c\u901a\u5e38\u8bbe\u7f6e\u4e3a ReLU (Rectified Linear Unit) \u6fc0\u6d3b\u51fd\u6570\uff0c_%5e_2_%5e_
\n
\u5728\u6b64\u57fa\u7840\u4e0a\uff0c FFN \u51fd\u6570\u53ef\u4ee5\u5199\u4f5c\uff1a_%5e_3_%5e_ \u5176\u4e2d_%5e_4_%5e_ _%5e_5_%5e_ \u3001_%5e_6_%5e_ \u548c_%5e_7_%5e_ \u662f\u53ef\u5b66\u4e60\u7684\u53c2\u6570\u3002
\n
\u6709\u65f6\u8fd8\u4f1a\u4f7f\u7528 GELU (Gaussian Error Linear Unit) \u6fc0\u6d3b\u51fd\u6570\u6765\u4ee3\u66ff ReLU \u3002_%5e_8_%5e_ \u5176\u4e2d_%5e_9_%5e_
\n
\u95e8\u63a7\u7ebf\u6027\u5355\u5143 \n
\u8fd9\u662f\u4e00\u4e2a\u901a\u7528\u5b9e\u73b0\uff0c\u652f\u6301\u5305\u62ec\u95e8\u63a7\u7ebf\u6027\u5355\u5143(GLU) \u5728\u5185\u7684\u4e0d\u540c\u53d8\u4f53\u3002\u6211\u4eec\u8fd8\u5bf9\u8fd9\u4e9b\u8fdb\u884c\u4e86\u5b9e\u9a8c\uff1a
\n
\n",
"
FFN module \n": "
FFN \u6a21\u5757 \n",
"
_^_0_^_
\n": "
_^_0_^_
\n",
"
_^_0_^_ or _^_1_^_ depending on whether it is gated
\n": "
\u6839\u636e\u662f\u5426\u8fdb\u884c\u95e8\u63a7\uff0c\u8fd4\u56de_^_0_^_ \u6216\u8005_^_1_^_
\n",
From 66e92edb045c9b6b1d01b3f3d41b92fd5ef2258e Mon Sep 17 00:00:00 2001
From: Seas0
Date: Mon, 15 Jul 2024 13:06:40 +0800
Subject: [PATCH 02/16] Fix typo in Wasserstein GAN
---
labml_nn/gan/wasserstein/__init__.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/labml_nn/gan/wasserstein/__init__.py b/labml_nn/gan/wasserstein/__init__.py
index b3c52472..3c5394e0 100644
--- a/labml_nn/gan/wasserstein/__init__.py
+++ b/labml_nn/gan/wasserstein/__init__.py
@@ -26,7 +26,7 @@ marginal probabilities are $\gamma(x, y)$.
$\mathbb{E}_{(x,y) \sim \gamma} \Vert x - y \Vert$ is the earth mover distance for
a given joint distribution ($x$ and $y$ are probabilities).
-So $W(\mathbb{P}_r, \mathbb{P}g)$ is equal to the least earth mover distance for
+So $W(\mathbb{P}_r, \mathbb{P}_g)$ is equal to the least earth mover distance for
any joint distribution between the real distribution $\mathbb{P}_r$ and generated distribution $\mathbb{P}_g$.
The paper shows that Jensen-Shannon (JS) divergence and other measures for the difference between two probability
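For context, the quantity whose subscript this patch fixes is the Wasserstein (earth mover) distance; written out in the notation of the surrounding docstring it is

$$W(\mathbb{P}_r, \mathbb{P}_g) = \inf_{\gamma \in \Pi(\mathbb{P}_r, \mathbb{P}_g)} \mathbb{E}_{(x, y) \sim \gamma} \Vert x - y \Vert$$

where $\Pi(\mathbb{P}_r, \mathbb{P}_g)$ is the set of all joint distributions $\gamma(x, y)$ whose marginals are $\mathbb{P}_r$ and $\mathbb{P}_g$.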
From cbc38bb26be2034f98e12e24e2d376a982fd1a71 Mon Sep 17 00:00:00 2001
From: lakshith
Date: Fri, 26 Jul 2024 09:41:13 +0530
Subject: [PATCH 03/16] GPT 2 implementation
---
docs/transformers/LoRA/GPT2.py | 239 ++++++++++++++++++++++
docs/transformers/LoRA/gpt2_state_dict.py | 35 ++++
2 files changed, 274 insertions(+)
create mode 100644 docs/transformers/LoRA/GPT2.py
create mode 100644 docs/transformers/LoRA/gpt2_state_dict.py
diff --git a/docs/transformers/LoRA/GPT2.py b/docs/transformers/LoRA/GPT2.py
new file mode 100644
index 00000000..d772874b
--- /dev/null
+++ b/docs/transformers/LoRA/GPT2.py
@@ -0,0 +1,239 @@
+import torch
+import torch.nn as nn
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+# config from GPT
+config = {
+ "_name_or_path": "gpt2",
+ "activation_function": "gelu_new",
+ "architectures": [
+ "GPT2LMHeadModel"
+ ],
+ "attn_pdrop": 0.1,
+ "bos_token_id": 50256,
+ "embd_pdrop": 0.1,
+ "eos_token_id": 0,
+ "initializer_range": 0.02,
+ "layer_norm_epsilon": 1e-05,
+ "model_type": "gpt2",
+ "n_ctx": 1024,
+ "n_embd": 768,
+ "n_head": 12,
+ "n_inner": None,
+ "n_layer": 12,
+ "n_positions": 1024,
+ "reorder_and_upcast_attn": False,
+ "resid_pdrop": 0.1,
+ "scale_attn_by_inverse_layer_idx": False,
+ "scale_attn_weights": True,
+ "summary_activation": None,
+ "summary_first_dropout": 0.1,
+ "summary_proj_to_labels": True,
+ "summary_type": "cls_index",
+ "summary_use_proj": True,
+ "task_specific_params": {
+ "text-generation": {
+ "do_sample": True,
+ "max_length": 50
+ }
+ },
+ "transformers_version": "4.42.4",
+ "use_cache": True,
+ "vocab_size": 50257
+}
+
+import math
+from torch import Tensor
+
+
+# from transformers
+class Conv1D(nn.Module):
+ """
+ 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
+
+ Basically works like a linear layer but the weights are transposed.
+
+ Args:
+ nf (`int`): The number of output features.
+ nx (`int`): The number of input features.
+ """
+
+ def __init__(self, nf, nx):
+ super().__init__()
+ self.nf = nf
+ self.weight = nn.Parameter(torch.empty(nx, nf))
+ self.bias = nn.Parameter(torch.zeros(nf))
+ nn.init.normal_(self.weight, std=0.02)
+
+ def forward(self, x):
+ size_out = x.size()[:-1] + (self.nf,)
+ x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+ x = x.view(size_out)
+ return x
+
+
+# from transformers
+class NewGELUActivation(nn.Module):
+ """
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+
+
+class HeadFFN(nn.Module): # todo rename
+ def __init__(self, dim):
+ super().__init__()
+ self.c_fc = Conv1D(dim, config['n_embd'])
+ self.c_proj = Conv1D(config['n_embd'], dim)
+ self.act = NewGELUActivation()
+ self.dropout = nn.Dropout(config['resid_pdrop'])
+
+ def forward(self, hidden_states):
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class MultiHead(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.embed_dim = config['n_embd']
+ self.num_heads = config['n_head']
+ self.head_dim = self.embed_dim // self.num_heads
+ self.split_size = self.embed_dim
+
+ self.c_att = Conv1D(config['n_embd'] * 3, config['n_embd'])
+ self.c_proj = Conv1D(config['n_embd'], config['n_embd'])
+
+ self.resid_dropout = nn.Dropout(config['resid_pdrop'])
+ self.attn_dropout = nn.Dropout(config['attn_pdrop'])
+
+ def _split_heads(self, tensor, num_heads, attn_head_size):
+ """
+ Splits hidden_size dim into attn_head_size and num_heads
+ """
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+ tensor = tensor.view(new_shape)
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
+
+ def forward(self, hidden_states):
+ batch_size, seq_length, _ = hidden_states.size()
+
+ query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)
+
+ query = self._split_heads(query, self.num_heads, self.head_dim)
+ key = self._split_heads(key, self.num_heads, self.head_dim)
+ value = self._split_heads(value, self.num_heads, self.head_dim)
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=self.attn_dropout.p if self.training else 0.0,
+ is_causal=True, # for the triangular mask
+ )
+
+ # todo why this?
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)
+
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+
+ return attn_output
+
+
+class Block(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.pre_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
+ self.attn = MultiHead()
+ self.post_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
+ self.ffn = HeadFFN(config['n_embd'] * 4)
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ hidden_states = self.pre_norm(hidden_states)
+
+ attn_output = self.attn(hidden_states)
+
+ hidden_states = attn_output + residual
+ residual = hidden_states
+ hidden_states = self.post_norm(hidden_states)
+ feed_forward_output = self.ffn(hidden_states)
+ hidden_states = feed_forward_output + residual
+
+ return hidden_states
+
+
+class GPTModel(nn.Module):
+ # todo ignored token type embeds, past key values
+ def __init__(self):
+ super().__init__()
+
+ self.token_embedding = nn.Embedding(config['vocab_size'], config['n_embd'])
+ self.position_embedding = nn.Embedding(config['n_positions'], config['n_embd'])
+
+ self.dropout = nn.Dropout(p=config['embd_pdrop'], inplace=False)
+
+ self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])
+
+ self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
+
+ self.lm_head = nn.Linear(config['n_embd'], config['vocab_size'], bias=False)
+
+ def forward(self, input_ids):
+ batch_size, input_shape = input_ids.size()
+
+ token_embeddings = self.token_embedding(input_ids) # B T C
+ position_ids = torch.arange(input_shape) # T C
+ position_embeddings = self.position_embedding(position_ids) # B T C
+
+ embeddings = token_embeddings + position_embeddings
+
+ hidden_states = self.dropout(embeddings)
+
+ for block in self.blocks:
+ hidden_states = block(hidden_states)
+
+ hidden_states = self.final_norm(hidden_states)
+
+ logits = self.lm_head(hidden_states)
+
+ return logits
+
+
+model = GPTModel()
+
+state_dict = torch.load('transformed.pth')
+
+missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+if missing_keys:
+ print(f"Missing keys: {missing_keys}")
+if unexpected_keys:
+ print(f"Unexpected keys: {unexpected_keys}")
+
+prompt = "hello how are you"
+tokenized = tokenizer(prompt, return_tensors="pt")
+
+with torch.no_grad():
+ model.eval()
+ res = model(tokenized['input_ids'])
+
+print(res)
+
+output_ids = torch.argmax(res, dim=-1)
+
+# Decode the token indices back to text
+output_text = tokenizer.decode(output_ids[0])
+
+# Print the tokens of the output
+print(output_text)
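A possible greedy-decoding loop on top of the model defined above (generate_greedy is a hypothetical helper, not part of the patch); it re-runs the model each step and appends the argmax token of the last position:

def generate_greedy(model, tokenizer, prompt: str, max_new_tokens: int = 20) -> str:
    # Greedy autoregressive decoding: feed the growing sequence back into the model.
    model.eval()
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(input_ids)                                # [batch, seq, vocab]
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # token with the highest logit at the last position
            input_ids = torch.cat([input_ids, next_id], dim=-1)
    return tokenizer.decode(input_ids[0])

# print(generate_greedy(model, tokenizer, "hello how are you"))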
diff --git a/docs/transformers/LoRA/gpt2_state_dict.py b/docs/transformers/LoRA/gpt2_state_dict.py
new file mode 100644
index 00000000..09f27eaf
--- /dev/null
+++ b/docs/transformers/LoRA/gpt2_state_dict.py
@@ -0,0 +1,35 @@
+import torch
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+state_dict = model.state_dict()
+
+mapping = {
+ 'transformer.wte.weight': 'token_embedding.weight',
+ 'transformer.wpe.weight': 'position_embedding.weight',
+ 'transformer.ln_f.weight': 'final_norm.weight',
+ 'transformer.ln_f.bias': 'final_norm.bias',
+ 'lm_head.weight': 'lm_head.weight'
+}
+
+for i in range(12):
+ mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
+ mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
+ mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
+ mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
+ mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
+ mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
+ mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
+ mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
+ mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
+ mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
+ mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
+ mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'
+
+new_state_dict = {}
+for old_key, new_key in mapping.items():
+ if old_key in state_dict:
+ new_state_dict[new_key] = state_dict[old_key]
+
+torch.save(new_state_dict, 'transformed.pth')
From b3aedf3093272c1f658a09b5a7544e2625c5732c Mon Sep 17 00:00:00 2001
From: lakshith
Date: Sat, 27 Jul 2024 21:28:07 +0530
Subject: [PATCH 04/16] remove gelu custom impl and use pytorch impl
---
docs/transformers/LoRA/GPT2.py | 16 +---------------
1 file changed, 1 insertion(+), 15 deletions(-)
diff --git a/docs/transformers/LoRA/GPT2.py b/docs/transformers/LoRA/GPT2.py
index d772874b..ae47320a 100644
--- a/docs/transformers/LoRA/GPT2.py
+++ b/docs/transformers/LoRA/GPT2.py
@@ -44,9 +44,6 @@ config = {
"vocab_size": 50257
}
-import math
-from torch import Tensor
-
# from transformers
class Conv1D(nn.Module):
@@ -74,23 +71,12 @@ class Conv1D(nn.Module):
return x
-# from transformers
-class NewGELUActivation(nn.Module):
- """
- Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
- the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
- """
-
- def forward(self, input: Tensor) -> Tensor:
- return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
-
-
class HeadFFN(nn.Module): # todo rename
def __init__(self, dim):
super().__init__()
self.c_fc = Conv1D(dim, config['n_embd'])
self.c_proj = Conv1D(config['n_embd'], dim)
- self.act = NewGELUActivation()
+ self.act = nn.functional.gelu
self.dropout = nn.Dropout(config['resid_pdrop'])
def forward(self, hidden_states):
From 106e72605da5831251aa0e2d7b671e0a1175ba97 Mon Sep 17 00:00:00 2001
From: lakshith
Date: Sat, 27 Jul 2024 21:30:15 +0530
Subject: [PATCH 05/16] remove dropout layers
---
docs/transformers/LoRA/GPT2.py | 14 ++------------
1 file changed, 2 insertions(+), 12 deletions(-)
diff --git a/docs/transformers/LoRA/GPT2.py b/docs/transformers/LoRA/GPT2.py
index ae47320a..9c7887be 100644
--- a/docs/transformers/LoRA/GPT2.py
+++ b/docs/transformers/LoRA/GPT2.py
@@ -77,13 +77,11 @@ class HeadFFN(nn.Module): # todo rename
self.c_fc = Conv1D(dim, config['n_embd'])
self.c_proj = Conv1D(config['n_embd'], dim)
self.act = nn.functional.gelu
- self.dropout = nn.Dropout(config['resid_pdrop'])
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
- hidden_states = self.dropout(hidden_states)
return hidden_states
@@ -98,9 +96,6 @@ class MultiHead(nn.Module):
self.c_att = Conv1D(config['n_embd'] * 3, config['n_embd'])
self.c_proj = Conv1D(config['n_embd'], config['n_embd'])
- self.resid_dropout = nn.Dropout(config['resid_pdrop'])
- self.attn_dropout = nn.Dropout(config['attn_pdrop'])
-
def _split_heads(self, tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
@@ -123,7 +118,7 @@ class MultiHead(nn.Module):
key,
value,
attn_mask=None,
- dropout_p=self.attn_dropout.p if self.training else 0.0,
+ dropout_p=0.0,
is_causal=True, # for the triangular mask
)
@@ -132,7 +127,6 @@ class MultiHead(nn.Module):
attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)
attn_output = self.c_proj(attn_output)
- attn_output = self.resid_dropout(attn_output)
return attn_output
@@ -168,8 +162,6 @@ class GPTModel(nn.Module):
self.token_embedding = nn.Embedding(config['vocab_size'], config['n_embd'])
self.position_embedding = nn.Embedding(config['n_positions'], config['n_embd'])
- self.dropout = nn.Dropout(p=config['embd_pdrop'], inplace=False)
-
self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])
self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
@@ -183,9 +175,7 @@ class GPTModel(nn.Module):
position_ids = torch.arange(input_shape) # T C
position_embeddings = self.position_embedding(position_ids) # B T C
- embeddings = token_embeddings + position_embeddings
-
- hidden_states = self.dropout(embeddings)
+ hidden_states = token_embeddings + position_embeddings
for block in self.blocks:
hidden_states = block(hidden_states)
From 50c3cc4eab487baa88ca974f5edb379e030a0a95 Mon Sep 17 00:00:00 2001
From: lakshith
Date: Sat, 27 Jul 2024 22:01:21 +0530
Subject: [PATCH 06/16] keep only required configs
---
docs/transformers/LoRA/GPT2.py | 31 -------------------------------
1 file changed, 31 deletions(-)
diff --git a/docs/transformers/LoRA/GPT2.py b/docs/transformers/LoRA/GPT2.py
index 9c7887be..36d9b74c 100644
--- a/docs/transformers/LoRA/GPT2.py
+++ b/docs/transformers/LoRA/GPT2.py
@@ -4,43 +4,12 @@ from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
-# config from GPT
config = {
- "_name_or_path": "gpt2",
- "activation_function": "gelu_new",
- "architectures": [
- "GPT2LMHeadModel"
- ],
- "attn_pdrop": 0.1,
- "bos_token_id": 50256,
- "embd_pdrop": 0.1,
- "eos_token_id": 0,
- "initializer_range": 0.02,
"layer_norm_epsilon": 1e-05,
- "model_type": "gpt2",
- "n_ctx": 1024,
"n_embd": 768,
"n_head": 12,
- "n_inner": None,
"n_layer": 12,
"n_positions": 1024,
- "reorder_and_upcast_attn": False,
- "resid_pdrop": 0.1,
- "scale_attn_by_inverse_layer_idx": False,
- "scale_attn_weights": True,
- "summary_activation": None,
- "summary_first_dropout": 0.1,
- "summary_proj_to_labels": True,
- "summary_type": "cls_index",
- "summary_use_proj": True,
- "task_specific_params": {
- "text-generation": {
- "do_sample": True,
- "max_length": 50
- }
- },
- "transformers_version": "4.42.4",
- "use_cache": True,
"vocab_size": 50257
}
From d1e8daa1212c6d99de09c5d258fb4a3641d9ab31 Mon Sep 17 00:00:00 2001
From: lakshith
Date: Sun, 28 Jul 2024 08:51:03 +0530
Subject: [PATCH 07/16] replace Conv1D layers with linear
---
docs/transformers/LoRA/GPT2.py | 34 +++--------------------
docs/transformers/LoRA/gpt2_state_dict.py | 9 ++++++
2 files changed, 13 insertions(+), 30 deletions(-)
diff --git a/docs/transformers/LoRA/GPT2.py b/docs/transformers/LoRA/GPT2.py
index 36d9b74c..35a65273 100644
--- a/docs/transformers/LoRA/GPT2.py
+++ b/docs/transformers/LoRA/GPT2.py
@@ -14,37 +14,11 @@ config = {
}
-# from transformers
-class Conv1D(nn.Module):
- """
- 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
-
- Basically works like a linear layer but the weights are transposed.
-
- Args:
- nf (`int`): The number of output features.
- nx (`int`): The number of input features.
- """
-
- def __init__(self, nf, nx):
- super().__init__()
- self.nf = nf
- self.weight = nn.Parameter(torch.empty(nx, nf))
- self.bias = nn.Parameter(torch.zeros(nf))
- nn.init.normal_(self.weight, std=0.02)
-
- def forward(self, x):
- size_out = x.size()[:-1] + (self.nf,)
- x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
- x = x.view(size_out)
- return x
-
-
class HeadFFN(nn.Module): # todo rename
def __init__(self, dim):
super().__init__()
- self.c_fc = Conv1D(dim, config['n_embd'])
- self.c_proj = Conv1D(config['n_embd'], dim)
+ self.c_fc = nn.Linear(config['n_embd'], dim)
+ self.c_proj = nn.Linear(dim, config['n_embd'])
self.act = nn.functional.gelu
def forward(self, hidden_states):
@@ -62,8 +36,8 @@ class MultiHead(nn.Module):
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
- self.c_att = Conv1D(config['n_embd'] * 3, config['n_embd'])
- self.c_proj = Conv1D(config['n_embd'], config['n_embd'])
+ self.c_att = nn.Linear(config['n_embd'], config['n_embd'] * 3)
+ self.c_proj = nn.Linear(config['n_embd'], config['n_embd'])
def _split_heads(self, tensor, num_heads, attn_head_size):
"""
diff --git a/docs/transformers/LoRA/gpt2_state_dict.py b/docs/transformers/LoRA/gpt2_state_dict.py
index 09f27eaf..0e8ff6be 100644
--- a/docs/transformers/LoRA/gpt2_state_dict.py
+++ b/docs/transformers/LoRA/gpt2_state_dict.py
@@ -32,4 +32,13 @@ for old_key, new_key in mapping.items():
if old_key in state_dict:
new_state_dict[new_key] = state_dict[old_key]
+# transpose weight matrices of Conv1D layers to use linear layers instead
+convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
+ [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
+ [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
+ [f'blocks.{i}.attn.c_proj.weight' for i in range(12)])
+
+for layer in convo_layers:
+ new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
+
torch.save(new_state_dict, 'transformed.pth')
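A standalone check of the equivalence this patch relies on: GPT-2's Conv1D stores its weight as (in_features, out_features) while nn.Linear stores (out_features, in_features), so the converted weights have to be transposed (the shapes below are illustrative):

import torch
import torch.nn as nn

w = torch.randn(768, 3 * 768)         # Conv1D weight layout: (nx, nf) = (in, out)
b = torch.zeros(3 * 768)
x = torch.randn(4, 768)

conv1d_out = torch.addmm(b, x, w)     # what Conv1D.forward computes: x @ w + b

linear = nn.Linear(768, 3 * 768)
with torch.no_grad():
    linear.weight.copy_(w.T)          # nn.Linear expects (out_features, in_features)
    linear.bias.copy_(b)

assert torch.allclose(conv1d_out, linear(x), atol=1e-5)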
From 8e756f292bce5b70453575be997d4e87acd43158 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Sun, 28 Jul 2024 11:22:27 +0530
Subject: [PATCH 08/16] lora layers
---
docs/transformers/LoRA/__init__.py | 68 ++++++++++++++++++++++++++++++
1 file changed, 68 insertions(+)
create mode 100644 docs/transformers/LoRA/__init__.py
diff --git a/docs/transformers/LoRA/__init__.py b/docs/transformers/LoRA/__init__.py
new file mode 100644
index 00000000..8955132e
--- /dev/null
+++ b/docs/transformers/LoRA/__init__.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+
+
+class Linear(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool,
+ r: int,
+ alpha: int = None):
+ if alpha is None:
+ alpha = r
+ super().__init__()
+ self.weight = nn.Parameter(torch.empty((out_features, in_features)))
+ self.weight.requires_grad = False
+
+ if bias:
+ self.bias = nn.Parameter(torch.empty(out_features))
+ self.bias.requires_grad = False
+ else:
+ self.bias = None
+
+ self.scaling = alpha / r
+ self.lora_a = nn.Parameter(torch.empty((in_features, r)))
+ self.lora_b = nn.Parameter(torch.empty((r, out_features)))
+
+ with torch.no_grad():
+ nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
+ nn.init.zeros_(self.lora_b)
+
+ def forward(self, x: torch.Tensor):
+ result = nn.functional.linear(x, self.weight, bias=self.bias)
+
+ result += (x @ self.lora_a @ self.lora_b) * self.scaling
+
+ return result
+
+
+class Embedding(nn.Module):
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ r: int,
+ alpha: int = None,
+ ):
+ if alpha is None:
+ alpha = r
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
+ self.weight.requires_grad = False
+
+ self.scaling = alpha / self.r
+ self.lora_a = nn.Parameter(torch.empty((num_embeddings, r)))
+ self.lora_b = nn.Parameter(torch.empty((r, embedding_dim)))
+
+ with torch.no_grad():
+ nn.init.normal_(self.lora_a)
+ nn.init.zeros_(self.lora_b)
+
+ def forward(self, x: torch.Tensor):
+ result = nn.functional.embedding(x, self.weight)
+ result += (nn.functional.embedding(x, self.lora_a) @ self.lora_b) * self.scaling
+
+ return result
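A small usage sketch for the Linear layer above (the base weight would normally be copied from a pretrained checkpoint; the init here is only illustrative): the base weight and bias stay frozen, only lora_a and lora_b receive gradients, and since lora_b starts at zero the layer initially matches the frozen base output.

import torch

layer = Linear(in_features=768, out_features=768, bias=True, r=32)

with torch.no_grad():
    torch.nn.init.normal_(layer.weight, std=0.02)   # stand-in for pretrained weights
    torch.nn.init.zeros_(layer.bias)

print([n for n, p in layer.named_parameters() if p.requires_grad])   # ['lora_a', 'lora_b']

x = torch.randn(4, 10, 768)
y = layer(x)          # frozen base output + (x @ lora_a @ lora_b) * alpha / r
print(y.shape)        # torch.Size([4, 10, 768])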
From c82529ce6771e3d375c44acd35777992da01a555 Mon Sep 17 00:00:00 2001
From: lakshith
Date: Mon, 29 Jul 2024 11:17:38 +0530
Subject: [PATCH 09/16] move LoRA to labml.nn
---
{docs => labml_nn}/transformers/LoRA/GPT2.py | 0
{docs => labml_nn}/transformers/LoRA/__init__.py | 0
{docs => labml_nn}/transformers/LoRA/gpt2_state_dict.py | 0
3 files changed, 0 insertions(+), 0 deletions(-)
rename {docs => labml_nn}/transformers/LoRA/GPT2.py (100%)
rename {docs => labml_nn}/transformers/LoRA/__init__.py (100%)
rename {docs => labml_nn}/transformers/LoRA/gpt2_state_dict.py (100%)
diff --git a/docs/transformers/LoRA/GPT2.py b/labml_nn/transformers/LoRA/GPT2.py
similarity index 100%
rename from docs/transformers/LoRA/GPT2.py
rename to labml_nn/transformers/LoRA/GPT2.py
diff --git a/docs/transformers/LoRA/__init__.py b/labml_nn/transformers/LoRA/__init__.py
similarity index 100%
rename from docs/transformers/LoRA/__init__.py
rename to labml_nn/transformers/LoRA/__init__.py
diff --git a/docs/transformers/LoRA/gpt2_state_dict.py b/labml_nn/transformers/LoRA/gpt2_state_dict.py
similarity index 100%
rename from docs/transformers/LoRA/gpt2_state_dict.py
rename to labml_nn/transformers/LoRA/gpt2_state_dict.py
From 23b7e2ee8e077496adf8e76b8435aff67e8d409d Mon Sep 17 00:00:00 2001
From: lakshith
Date: Mon, 29 Jul 2024 19:40:39 +0530
Subject: [PATCH 10/16] create experiment notebook and refactoring
---
labml_nn/transformers/LoRA/GPT2.py | 38 +-----
labml_nn/transformers/LoRA/experiment.ipynb | 125 ++++++++++++++++++
.../LoRA/{gpt2_state_dict.py => load_hf.py} | 0
3 files changed, 129 insertions(+), 34 deletions(-)
create mode 100644 labml_nn/transformers/LoRA/experiment.ipynb
rename labml_nn/transformers/LoRA/{gpt2_state_dict.py => load_hf.py} (100%)
diff --git a/labml_nn/transformers/LoRA/GPT2.py b/labml_nn/transformers/LoRA/GPT2.py
index 35a65273..11b92e2d 100644
--- a/labml_nn/transformers/LoRA/GPT2.py
+++ b/labml_nn/transformers/LoRA/GPT2.py
@@ -14,7 +14,7 @@ config = {
}
-class HeadFFN(nn.Module): # todo rename
+class FFN(nn.Module):
def __init__(self, dim):
super().__init__()
self.c_fc = nn.Linear(config['n_embd'], dim)
@@ -28,7 +28,7 @@ class HeadFFN(nn.Module): # todo rename
return hidden_states
-class MultiHead(nn.Module):
+class MultiHeadAttention(nn.Module):
def __init__(self):
super().__init__()
self.embed_dim = config['n_embd']
@@ -65,7 +65,6 @@ class MultiHead(nn.Module):
is_causal=True, # for the triangular mask
)
- # todo why this?
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)
@@ -78,9 +77,9 @@ class Block(nn.Module):
def __init__(self):
super().__init__()
self.pre_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
- self.attn = MultiHead()
+ self.attn = MultiHeadAttention()
self.post_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
- self.ffn = HeadFFN(config['n_embd'] * 4)
+ self.ffn = FFN(config['n_embd'] * 4)
def forward(self, hidden_states):
residual = hidden_states
@@ -98,7 +97,6 @@ class Block(nn.Module):
class GPTModel(nn.Module):
- # todo ignored token type embeds, past key values
def __init__(self):
super().__init__()
@@ -128,31 +126,3 @@ class GPTModel(nn.Module):
logits = self.lm_head(hidden_states)
return logits
-
-
-model = GPTModel()
-
-state_dict = torch.load('transformed.pth')
-
-missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
-if missing_keys:
- print(f"Missing keys: {missing_keys}")
-if unexpected_keys:
- print(f"Unexpected keys: {unexpected_keys}")
-
-prompt = "hello how are you"
-tokenized = tokenizer(prompt, return_tensors="pt")
-
-with torch.no_grad():
- model.eval()
- res = model(tokenized['input_ids'])
-
-print(res)
-
-output_ids = torch.argmax(res, dim=-1)
-
-# Decode the token indices back to text
-output_text = tokenizer.decode(output_ids[0])
-
-# Print the tokens of the output
-print(output_text)
diff --git a/labml_nn/transformers/LoRA/experiment.ipynb b/labml_nn/transformers/LoRA/experiment.ipynb
new file mode 100644
index 00000000..eb07a516
--- /dev/null
+++ b/labml_nn/transformers/LoRA/experiment.ipynb
@@ -0,0 +1,125 @@
+{
+ "cells": [
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-29T07:14:27.781097Z",
+ "start_time": "2024-07-29T07:14:24.819976Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
+ "import torch"
+ ],
+ "id": "cffa3ec341b4905a",
+ "outputs": [],
+ "execution_count": 1
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-29T07:14:28.183960Z",
+ "start_time": "2024-07-29T07:14:27.782683Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
+ ],
+ "id": "c2b0b7e18394ea9e",
+ "outputs": [],
+ "execution_count": 2
+ },
+ {
+ "cell_type": "code",
+ "id": "initial_id",
+ "metadata": {
+ "collapsed": true,
+ "ExecuteTime": {
+ "end_time": "2024-07-29T07:14:29.840925Z",
+ "start_time": "2024-07-29T07:14:28.185080Z"
+ }
+ },
+ "source": [
+ "model = GPTModel()\n",
+ "\n",
+ "state_dict = torch.load('transformed.pth')\n",
+ "\n",
+ "missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)\n",
+ "if missing_keys:\n",
+ " print(f\"Missing keys: {missing_keys}\")\n",
+ "if unexpected_keys:\n",
+ " print(f\"Unexpected keys: {unexpected_keys}\")"
+ ],
+ "outputs": [],
+ "execution_count": 3
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-29T07:22:30.408855Z",
+ "start_time": "2024-07-29T07:22:30.168376Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "prompt = \"hello how are you\"\n",
+ "tokenized = tokenizer(prompt, return_tensors=\"pt\")\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " model.eval()\n",
+ " res = model(tokenized['input_ids'])\n",
+ "\n",
+ "output_ids = torch.argmax(res, dim=-1)\n",
+ "for id in output_ids[0]:\n",
+ " print(tokenizer.decode(id))"
+ ],
+ "id": "f4f7826ec3729b66",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ ",\n",
+ " to\n",
+ " you\n",
+ " doing\n"
+ ]
+ }
+ ],
+ "execution_count": 17
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null,
+ "source": "",
+ "id": "c12776360008a974"
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python (ml)",
+ "language": "python",
+ "name": "ml"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/labml_nn/transformers/LoRA/gpt2_state_dict.py b/labml_nn/transformers/LoRA/load_hf.py
similarity index 100%
rename from labml_nn/transformers/LoRA/gpt2_state_dict.py
rename to labml_nn/transformers/LoRA/load_hf.py
From 0f2a9be6d27023eb4c33130cc10d06d5c71b8f7b Mon Sep 17 00:00:00 2001
From: lakshith
Date: Mon, 29 Jul 2024 23:01:06 +0530
Subject: [PATCH 11/16] training loop
---
labml_nn/transformers/LoRA/train.ipynb | 162 +++++++++++++++++++++++++
1 file changed, 162 insertions(+)
create mode 100644 labml_nn/transformers/LoRA/train.ipynb
diff --git a/labml_nn/transformers/LoRA/train.ipynb b/labml_nn/transformers/LoRA/train.ipynb
new file mode 100644
index 00000000..342ba78d
--- /dev/null
+++ b/labml_nn/transformers/LoRA/train.ipynb
@@ -0,0 +1,162 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "id": "initial_id",
+ "metadata": {
+ "collapsed": true
+ },
+ "source": "# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "with open('input.txt', 'r', encoding='utf-8') as f:\n",
+ " text = f.read()"
+ ],
+ "id": "3b1e507015ba6b81",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
+ "\n",
+ "tokens = tokenizer.encode(text, add_special_tokens=False)"
+ ],
+ "id": "ac8e51ae5bbfcae7",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "context_length = 10\n",
+ "batch_size = 64"
+ ],
+ "id": "aeefcdf813e427e",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "num_batches = len(tokens) // (batch_size * context_length)\n",
+ "tokens = tokens[:num_batches * batch_size * context_length]"
+ ],
+ "id": "a384b42274f008a2",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "\n",
+ "input_ids = torch.tensor(tokens).view(-1, context_length)"
+ ],
+ "id": "5c4cc78ac1a02c1d",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "from torch.utils.data import DataLoader, TensorDataset\n",
+ "from torch.optim import Adam\n",
+ "print(input_ids.shape)\n",
+ "dataset = TensorDataset(input_ids)\n",
+ "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)"
+ ],
+ "id": "7037fd75e2161382",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
+ "\n",
+ "model = GPTModel()"
+ ],
+ "id": "a98b7baa064b8494",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": [
+ "optimizer = Adam(model.parameters(), lr=5e-5)\n",
+ "criterion = torch.nn.CrossEntropyLoss()\n",
+ "\n",
+ "model.eval()\n",
+ "epochs = 3\n",
+ "for epoch in range(epochs):\n",
+ " for batch in dataloader:\n",
+ " inputs = batch[0]\n",
+ " labels = inputs.clone()\n",
+ " \n",
+ " outputs = model(inputs)\n",
+ " \n",
+ " shift_logits = outputs[..., :-1, :]\n",
+ " shift_labels = labels[..., 1:]\n",
+ " \n",
+ " loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
+ " \n",
+ " optimizer.zero_grad()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ "\n",
+ " print(f'Epoch: {epoch + 1}, Loss: {loss.item()}')\n",
+ " break\n",
+ "\n",
+ "print(\"Training complete.\")"
+ ],
+ "id": "e2f5076894770740",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "source": "",
+ "id": "da2d4023002648dc",
+ "outputs": [],
+ "execution_count": null
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python (ml)",
+ "language": "python",
+ "name": "ml"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
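The notebook's loss computation, pulled out into a standalone sketch (shapes assume logits of [batch, context_length, vocab]): the logits at position t are scored against the token at position t + 1.

import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    # Drop the last logit and the first label so that position t predicts token t + 1.
    shift_logits = logits[..., :-1, :]
    shift_labels = input_ids[..., 1:]
    return F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)),
                           shift_labels.reshape(-1))

logits = torch.randn(2, 512, 50257)           # [batch_size, context_length, vocab_size]
input_ids = torch.randint(0, 50257, (2, 512))
print(causal_lm_loss(logits, input_ids))      # a scalar loss tensor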
From 77d00f089b56870ff9d1240c73dd433767cd366a Mon Sep 17 00:00:00 2001
From: lakshith
Date: Wed, 31 Jul 2024 18:29:24 +0530
Subject: [PATCH 12/16] Add LoRA to GPT2
---
labml_nn/transformers/LoRA/GPT2.py | 20 +-
labml_nn/transformers/LoRA/__init__.py | 2 +-
labml_nn/transformers/LoRA/experiment.ipynb | 55 ++--
labml_nn/transformers/LoRA/train.ipynb | 272 +++++++++++++++-----
4 files changed, 260 insertions(+), 89 deletions(-)
diff --git a/labml_nn/transformers/LoRA/GPT2.py b/labml_nn/transformers/LoRA/GPT2.py
index 11b92e2d..a7a59342 100644
--- a/labml_nn/transformers/LoRA/GPT2.py
+++ b/labml_nn/transformers/LoRA/GPT2.py
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer
+from labml_nn.transformers.LoRA import Linear, Embedding
tokenizer = AutoTokenizer.from_pretrained("gpt2")
@@ -10,15 +11,16 @@ config = {
"n_head": 12,
"n_layer": 12,
"n_positions": 1024,
- "vocab_size": 50257
+ "vocab_size": 50257,
+ "device": "cuda"
}
class FFN(nn.Module):
def __init__(self, dim):
super().__init__()
- self.c_fc = nn.Linear(config['n_embd'], dim)
- self.c_proj = nn.Linear(dim, config['n_embd'])
+ self.c_fc = Linear(config['n_embd'], dim, r=32, bias=True)
+ self.c_proj = Linear(dim, config['n_embd'], r=32, bias=True)
self.act = nn.functional.gelu
def forward(self, hidden_states):
@@ -36,8 +38,8 @@ class MultiHeadAttention(nn.Module):
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
- self.c_att = nn.Linear(config['n_embd'], config['n_embd'] * 3)
- self.c_proj = nn.Linear(config['n_embd'], config['n_embd'])
+ self.c_att = Linear(config['n_embd'], config['n_embd'] * 3, r=32, bias=True)
+ self.c_proj = Linear(config['n_embd'], config['n_embd'], r=32, bias=True)
def _split_heads(self, tensor, num_heads, attn_head_size):
"""
@@ -100,20 +102,20 @@ class GPTModel(nn.Module):
def __init__(self):
super().__init__()
- self.token_embedding = nn.Embedding(config['vocab_size'], config['n_embd'])
- self.position_embedding = nn.Embedding(config['n_positions'], config['n_embd'])
+ self.token_embedding = Embedding(config['vocab_size'], config['n_embd'], r=32)
+ self.position_embedding = Embedding(config['n_positions'], config['n_embd'], r=32)
self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])
self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
- self.lm_head = nn.Linear(config['n_embd'], config['vocab_size'], bias=False)
+ self.lm_head = Linear(config['n_embd'], config['vocab_size'], r=32, bias=False)
def forward(self, input_ids):
batch_size, input_shape = input_ids.size()
token_embeddings = self.token_embedding(input_ids) # B T C
- position_ids = torch.arange(input_shape) # T C
+ position_ids = torch.arange(input_shape, device=config['device']) # T C
position_embeddings = self.position_embedding(position_ids) # B T C
hidden_states = token_embeddings + position_embeddings
diff --git a/labml_nn/transformers/LoRA/__init__.py b/labml_nn/transformers/LoRA/__init__.py
index 8955132e..302a4bf9 100644
--- a/labml_nn/transformers/LoRA/__init__.py
+++ b/labml_nn/transformers/LoRA/__init__.py
@@ -53,7 +53,7 @@ class Embedding(nn.Module):
self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
self.weight.requires_grad = False
- self.scaling = alpha / self.r
+ self.scaling = alpha / r
self.lora_a = nn.Parameter(torch.empty((num_embeddings, r)))
self.lora_b = nn.Parameter(torch.empty((r, embedding_dim)))
diff --git a/labml_nn/transformers/LoRA/experiment.ipynb b/labml_nn/transformers/LoRA/experiment.ipynb
index eb07a516..7070991d 100644
--- a/labml_nn/transformers/LoRA/experiment.ipynb
+++ b/labml_nn/transformers/LoRA/experiment.ipynb
@@ -3,8 +3,8 @@
{
"metadata": {
"ExecuteTime": {
- "end_time": "2024-07-29T07:14:27.781097Z",
- "start_time": "2024-07-29T07:14:24.819976Z"
+ "end_time": "2024-07-31T12:22:57.496965Z",
+ "start_time": "2024-07-31T12:22:55.151730Z"
}
},
"cell_type": "code",
@@ -19,8 +19,8 @@
{
"metadata": {
"ExecuteTime": {
- "end_time": "2024-07-29T07:14:28.183960Z",
- "start_time": "2024-07-29T07:14:27.782683Z"
+ "end_time": "2024-07-31T12:22:57.986397Z",
+ "start_time": "2024-07-31T12:22:57.498305Z"
}
},
"cell_type": "code",
@@ -39,8 +39,8 @@
"metadata": {
"collapsed": true,
"ExecuteTime": {
- "end_time": "2024-07-29T07:14:29.840925Z",
- "start_time": "2024-07-29T07:14:28.185080Z"
+ "end_time": "2024-07-31T12:22:58.562136Z",
+ "start_time": "2024-07-31T12:22:57.987296Z"
}
},
"source": [
@@ -54,20 +54,38 @@
"if unexpected_keys:\n",
" print(f\"Unexpected keys: {unexpected_keys}\")"
],
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/tmp/ipykernel_7130/2581223434.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
+ " state_dict = torch.load('transformed.pth')\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Missing keys: ['token_embedding.lora_a', 'token_embedding.lora_b', 'position_embedding.lora_a', 'position_embedding.lora_b', 'blocks.0.attn.c_att.lora_a', 'blocks.0.attn.c_att.lora_b', 'blocks.0.attn.c_proj.lora_a', 'blocks.0.attn.c_proj.lora_b', 'blocks.0.ffn.c_fc.lora_a', 'blocks.0.ffn.c_fc.lora_b', 'blocks.0.ffn.c_proj.lora_a', 'blocks.0.ffn.c_proj.lora_b', 'blocks.1.attn.c_att.lora_a', 'blocks.1.attn.c_att.lora_b', 'blocks.1.attn.c_proj.lora_a', 'blocks.1.attn.c_proj.lora_b', 'blocks.1.ffn.c_fc.lora_a', 'blocks.1.ffn.c_fc.lora_b', 'blocks.1.ffn.c_proj.lora_a', 'blocks.1.ffn.c_proj.lora_b', 'blocks.2.attn.c_att.lora_a', 'blocks.2.attn.c_att.lora_b', 'blocks.2.attn.c_proj.lora_a', 'blocks.2.attn.c_proj.lora_b', 'blocks.2.ffn.c_fc.lora_a', 'blocks.2.ffn.c_fc.lora_b', 'blocks.2.ffn.c_proj.lora_a', 'blocks.2.ffn.c_proj.lora_b', 'blocks.3.attn.c_att.lora_a', 'blocks.3.attn.c_att.lora_b', 'blocks.3.attn.c_proj.lora_a', 'blocks.3.attn.c_proj.lora_b', 'blocks.3.ffn.c_fc.lora_a', 'blocks.3.ffn.c_fc.lora_b', 'blocks.3.ffn.c_proj.lora_a', 'blocks.3.ffn.c_proj.lora_b', 'blocks.4.attn.c_att.lora_a', 'blocks.4.attn.c_att.lora_b', 'blocks.4.attn.c_proj.lora_a', 'blocks.4.attn.c_proj.lora_b', 'blocks.4.ffn.c_fc.lora_a', 'blocks.4.ffn.c_fc.lora_b', 'blocks.4.ffn.c_proj.lora_a', 'blocks.4.ffn.c_proj.lora_b', 'blocks.5.attn.c_att.lora_a', 'blocks.5.attn.c_att.lora_b', 'blocks.5.attn.c_proj.lora_a', 'blocks.5.attn.c_proj.lora_b', 'blocks.5.ffn.c_fc.lora_a', 'blocks.5.ffn.c_fc.lora_b', 'blocks.5.ffn.c_proj.lora_a', 'blocks.5.ffn.c_proj.lora_b', 'blocks.6.attn.c_att.lora_a', 'blocks.6.attn.c_att.lora_b', 'blocks.6.attn.c_proj.lora_a', 'blocks.6.attn.c_proj.lora_b', 'blocks.6.ffn.c_fc.lora_a', 'blocks.6.ffn.c_fc.lora_b', 'blocks.6.ffn.c_proj.lora_a', 'blocks.6.ffn.c_proj.lora_b', 'blocks.7.attn.c_att.lora_a', 'blocks.7.attn.c_att.lora_b', 'blocks.7.attn.c_proj.lora_a', 'blocks.7.attn.c_proj.lora_b', 'blocks.7.ffn.c_fc.lora_a', 'blocks.7.ffn.c_fc.lora_b', 'blocks.7.ffn.c_proj.lora_a', 'blocks.7.ffn.c_proj.lora_b', 'blocks.8.attn.c_att.lora_a', 'blocks.8.attn.c_att.lora_b', 'blocks.8.attn.c_proj.lora_a', 'blocks.8.attn.c_proj.lora_b', 'blocks.8.ffn.c_fc.lora_a', 'blocks.8.ffn.c_fc.lora_b', 'blocks.8.ffn.c_proj.lora_a', 'blocks.8.ffn.c_proj.lora_b', 'blocks.9.attn.c_att.lora_a', 'blocks.9.attn.c_att.lora_b', 'blocks.9.attn.c_proj.lora_a', 'blocks.9.attn.c_proj.lora_b', 'blocks.9.ffn.c_fc.lora_a', 'blocks.9.ffn.c_fc.lora_b', 'blocks.9.ffn.c_proj.lora_a', 'blocks.9.ffn.c_proj.lora_b', 'blocks.10.attn.c_att.lora_a', 'blocks.10.attn.c_att.lora_b', 'blocks.10.attn.c_proj.lora_a', 'blocks.10.attn.c_proj.lora_b', 'blocks.10.ffn.c_fc.lora_a', 'blocks.10.ffn.c_fc.lora_b', 'blocks.10.ffn.c_proj.lora_a', 'blocks.10.ffn.c_proj.lora_b', 'blocks.11.attn.c_att.lora_a', 'blocks.11.attn.c_att.lora_b', 'blocks.11.attn.c_proj.lora_a', 'blocks.11.attn.c_proj.lora_b', 'blocks.11.ffn.c_fc.lora_a', 'blocks.11.ffn.c_fc.lora_b', 'blocks.11.ffn.c_proj.lora_a', 'blocks.11.ffn.c_proj.lora_b', 'lm_head.lora_a', 'lm_head.lora_b']\n"
+ ]
+ }
+ ],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
- "end_time": "2024-07-29T07:22:30.408855Z",
- "start_time": "2024-07-29T07:22:30.168376Z"
+ "end_time": "2024-07-31T12:23:00.447976Z",
+ "start_time": "2024-07-31T12:22:58.566527Z"
}
},
"cell_type": "code",
"source": [
"prompt = \"hello how are you\"\n",
"tokenized = tokenizer(prompt, return_tensors=\"pt\")\n",
+ "tokenized['input_ids'] = tokenized['input_ids'].to('cuda')\n",
+ "model = model.to('cuda')\n",
"\n",
"with torch.no_grad():\n",
" model.eval()\n",
@@ -90,22 +108,27 @@
]
}
],
- "execution_count": 17
+ "execution_count": 4
},
{
- "metadata": {},
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-31T12:23:00.452060Z",
+ "start_time": "2024-07-31T12:23:00.448904Z"
+ }
+ },
"cell_type": "code",
- "outputs": [],
- "execution_count": null,
"source": "",
- "id": "c12776360008a974"
+ "id": "c12776360008a974",
+ "outputs": [],
+ "execution_count": 4
}
],
"metadata": {
"kernelspec": {
- "display_name": "Python (ml)",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
- "name": "ml"
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
diff --git a/labml_nn/transformers/LoRA/train.ipynb b/labml_nn/transformers/LoRA/train.ipynb
index 342ba78d..cd70bfb3 100644
--- a/labml_nn/transformers/LoRA/train.ipynb
+++ b/labml_nn/transformers/LoRA/train.ipynb
@@ -4,26 +4,44 @@
"cell_type": "code",
"id": "initial_id",
"metadata": {
- "collapsed": true
+ "collapsed": true,
+ "jupyter": {
+ "outputs_hidden": true
+ },
+ "ExecuteTime": {
+ "end_time": "2024-07-31T12:57:37.296030Z",
+ "start_time": "2024-07-31T12:57:37.292368Z"
+ }
},
- "source": "# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
+ "source": "# !wget https://raw.github/zusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
"outputs": [],
- "execution_count": null
+ "execution_count": 1
},
{
- "metadata": {},
"cell_type": "code",
+ "id": "3b1e507015ba6b81",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-31T12:57:37.317651Z",
+ "start_time": "2024-07-31T12:57:37.313808Z"
+ }
+ },
"source": [
"with open('input.txt', 'r', encoding='utf-8') as f:\n",
" text = f.read()"
],
- "id": "3b1e507015ba6b81",
"outputs": [],
- "execution_count": null
+ "execution_count": 2
},
{
- "metadata": {},
"cell_type": "code",
+ "id": "ac8e51ae5bbfcae7",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-31T12:57:40.488939Z",
+ "start_time": "2024-07-31T12:57:37.319486Z"
+ }
+ },
"source": [
"from transformers import AutoTokenizer\n",
"\n",
@@ -31,130 +49,258 @@
"\n",
"tokens = tokenizer.encode(text, add_special_tokens=False)"
],
- "id": "ac8e51ae5bbfcae7",
- "outputs": [],
- "execution_count": null
- },
- {
- "metadata": {},
- "cell_type": "code",
- "source": [
- "context_length = 10\n",
- "batch_size = 64"
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Token indices sequence length is longer than the specified maximum sequence length for this model (338025 > 1024). Running this sequence through the model will result in indexing errors\n"
+ ]
+ }
],
- "id": "aeefcdf813e427e",
- "outputs": [],
- "execution_count": null
+ "execution_count": 3
},
{
- "metadata": {},
"cell_type": "code",
+ "id": "aeefcdf813e427e",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-31T12:57:40.495510Z",
+ "start_time": "2024-07-31T12:57:40.490341Z"
+ }
+ },
+ "source": [
+ "context_length = 512\n",
+ "batch_size = 2"
+ ],
+ "outputs": [],
+ "execution_count": 4
+ },
+ {
+ "cell_type": "code",
+ "id": "a384b42274f008a2",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-31T12:57:40.522050Z",
+ "start_time": "2024-07-31T12:57:40.496842Z"
+ }
+ },
"source": [
"num_batches = len(tokens) // (batch_size * context_length)\n",
"tokens = tokens[:num_batches * batch_size * context_length]"
],
- "id": "a384b42274f008a2",
"outputs": [],
- "execution_count": null
+ "execution_count": 5
},
{
- "metadata": {},
"cell_type": "code",
+ "id": "5c4cc78ac1a02c1d",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-31T12:57:40.592272Z",
+ "start_time": "2024-07-31T12:57:40.524063Z"
+ }
+ },
"source": [
"import torch\n",
"\n",
"input_ids = torch.tensor(tokens).view(-1, context_length)"
],
- "id": "5c4cc78ac1a02c1d",
"outputs": [],
- "execution_count": null
+ "execution_count": 6
},
{
- "metadata": {},
"cell_type": "code",
+ "id": "7037fd75e2161382",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-31T12:57:40.601199Z",
+ "start_time": "2024-07-31T12:57:40.593250Z"
+ }
+ },
"source": [
"from torch.utils.data import DataLoader, TensorDataset\n",
"from torch.optim import Adam\n",
- "print(input_ids.shape)\n",
+ "from torch.utils.data import random_split\n",
+ "\n",
"dataset = TensorDataset(input_ids)\n",
- "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)"
+ "\n",
+ "train_ratio = 0.8\n",
+ "test_ratio = 0.2\n",
+ "\n",
+ "train_size = int(train_ratio * len(dataset))\n",
+ "test_size = len(dataset) - train_size\n",
+ "\n",
+ "train_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n",
+ "\n",
+ "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
+ "test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
],
- "id": "7037fd75e2161382",
"outputs": [],
- "execution_count": null
+ "execution_count": 7
},
{
- "metadata": {},
"cell_type": "code",
+ "id": "a98b7baa064b8494",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-31T12:57:41.577878Z",
+ "start_time": "2024-07-31T12:57:40.602187Z"
+ }
+ },
"source": [
"from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
"\n",
- "model = GPTModel()"
+ "model = GPTModel()\n",
+ "state_dict = torch.load('transformed.pth', weights_only=True)\n",
+ "\n",
+ "_ = model.load_state_dict(state_dict, strict=False)"
],
- "id": "a98b7baa064b8494",
"outputs": [],
- "execution_count": null
+ "execution_count": 8
},
{
- "metadata": {},
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-31T12:57:43.098187Z",
+ "start_time": "2024-07-31T12:57:41.578713Z"
+ }
+ },
"cell_type": "code",
"source": [
+ "device = \"cuda\"\n",
+ "model = model.to(device=\"cuda\")"
+ ],
+ "id": "2e0fa8b3082df716",
+ "outputs": [],
+ "execution_count": 9
+ },
+ {
+ "cell_type": "code",
+ "id": "e2f5076894770740",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-31T12:57:57.044755Z",
+ "start_time": "2024-07-31T12:57:43.099050Z"
+ }
+ },
+ "source": [
+ "from labml import tracker, experiment\n",
+ "\n",
"optimizer = Adam(model.parameters(), lr=5e-5)\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"\n",
- "model.eval()\n",
+ "model.train()\n",
"epochs = 3\n",
- "for epoch in range(epochs):\n",
- " for batch in dataloader:\n",
- " inputs = batch[0]\n",
- " labels = inputs.clone()\n",
- " \n",
- " outputs = model(inputs)\n",
- " \n",
- " shift_logits = outputs[..., :-1, :]\n",
- " shift_labels = labels[..., 1:]\n",
- " \n",
- " loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
- " \n",
- " optimizer.zero_grad()\n",
- " loss.backward()\n",
- " optimizer.step()\n",
+ "step = 0\n",
"\n",
+ "with experiment.record(name='LoRA.GPT2', app_url='http://localhost:5005/api/v1/track'):\n",
+ " for epoch in range(epochs):\n",
+ " for batch in train_dataloader:\n",
+ " inputs = batch[0]\n",
+ " inputs = inputs.to(device)\n",
+ " labels = inputs.clone()\n",
+ " \n",
+ " outputs = model(inputs)\n",
+ " \n",
+ " shift_logits = outputs[..., :-1, :]\n",
+ " shift_labels = labels[..., 1:]\n",
+ " \n",
+ " loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
+ " \n",
+ " optimizer.zero_grad()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " \n",
+ " tracker.save(step, {'loss': loss})\n",
+ " step += 1\n",
" print(f'Epoch: {epoch + 1}, Loss: {loss.item()}')\n",
- " break\n",
+ " \n",
+ " test_loss = 0\n",
+ " for batch in test_dataloader:\n",
+ " inputs = batch[0]\n",
+ " inputs = inputs.to(device)\n",
+ " labels = inputs.clone()\n",
+ " \n",
+ " outputs = model(inputs)\n",
+ " \n",
+ " shift_logits = outputs[..., :-1, :]\n",
+ " shift_labels = labels[..., 1:]\n",
+ " \n",
+ " loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
+ " \n",
+ " test_loss += loss.item()\n",
+ " test_loss /= len(test_dataloader)\n",
+ " tracker.save(step, {'test_loss': test_loss})\n",
+ " \n",
"\n",
"print(\"Training complete.\")"
],
- "id": "e2f5076894770740",
- "outputs": [],
- "execution_count": null
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "text/html": [
+ "\n",
+ "LoRA.GPT2 : 7a14822c4f3c11efad8354ef33f17c7c \n",
+ "\t[dirty]: \"training loop\" \n",
+ "Monitor experiment at http://localhost:5005/run/7a14822c4f3c11efad8354ef33f17c7c \n",
+ "Still updating labml server, please wait for it to complete... "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "ename": "KeyboardInterrupt",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
+ "\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
+ "Cell \u001B[0;32mIn[10], line 25\u001B[0m\n\u001B[1;32m 22\u001B[0m loss \u001B[38;5;241m=\u001B[39m criterion(shift_logits\u001B[38;5;241m.\u001B[39mreshape(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m, shift_logits\u001B[38;5;241m.\u001B[39msize(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m)), shift_labels\u001B[38;5;241m.\u001B[39mreshape(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m))\n\u001B[1;32m 24\u001B[0m optimizer\u001B[38;5;241m.\u001B[39mzero_grad()\n\u001B[0;32m---> 25\u001B[0m loss\u001B[38;5;241m.\u001B[39mbackward()\n\u001B[1;32m 26\u001B[0m optimizer\u001B[38;5;241m.\u001B[39mstep()\n\u001B[1;32m 28\u001B[0m tracker\u001B[38;5;241m.\u001B[39msave(step, {\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mloss\u001B[39m\u001B[38;5;124m'\u001B[39m: loss})\n",
+ "File \u001B[0;32m~/miniconda3/lib/python3.12/site-packages/torch/_tensor.py:521\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m 511\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m 512\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m 513\u001B[0m Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m 514\u001B[0m (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 519\u001B[0m inputs\u001B[38;5;241m=\u001B[39minputs,\n\u001B[1;32m 520\u001B[0m )\n\u001B[0;32m--> 521\u001B[0m torch\u001B[38;5;241m.\u001B[39mautograd\u001B[38;5;241m.\u001B[39mbackward(\n\u001B[1;32m 522\u001B[0m \u001B[38;5;28mself\u001B[39m, gradient, retain_graph, create_graph, inputs\u001B[38;5;241m=\u001B[39minputs\n\u001B[1;32m 523\u001B[0m )\n",
+ "File \u001B[0;32m~/miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:289\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m 284\u001B[0m retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m 286\u001B[0m \u001B[38;5;66;03m# The reason we repeat the same comment below is that\u001B[39;00m\n\u001B[1;32m 287\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m 288\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 289\u001B[0m _engine_run_backward(\n\u001B[1;32m 290\u001B[0m tensors,\n\u001B[1;32m 291\u001B[0m grad_tensors_,\n\u001B[1;32m 292\u001B[0m retain_graph,\n\u001B[1;32m 293\u001B[0m create_graph,\n\u001B[1;32m 294\u001B[0m inputs,\n\u001B[1;32m 295\u001B[0m allow_unreachable\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m,\n\u001B[1;32m 296\u001B[0m accumulate_grad\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m,\n\u001B[1;32m 297\u001B[0m )\n",
+ "File \u001B[0;32m~/miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:768\u001B[0m, in \u001B[0;36m_engine_run_backward\u001B[0;34m(t_outputs, *args, **kwargs)\u001B[0m\n\u001B[1;32m 766\u001B[0m unregister_hooks \u001B[38;5;241m=\u001B[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001B[1;32m 767\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m--> 768\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m Variable\u001B[38;5;241m.\u001B[39m_execution_engine\u001B[38;5;241m.\u001B[39mrun_backward( \u001B[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001B[39;00m\n\u001B[1;32m 769\u001B[0m t_outputs, \u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs\n\u001B[1;32m 770\u001B[0m ) \u001B[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001B[39;00m\n\u001B[1;32m 771\u001B[0m \u001B[38;5;28;01mfinally\u001B[39;00m:\n\u001B[1;32m 772\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m attach_logging_hooks:\n",
+ "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
+ ]
+ }
+ ],
+ "execution_count": 10
},
{
- "metadata": {},
"cell_type": "code",
- "source": "",
"id": "da2d4023002648dc",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-31T12:57:57.046254Z",
+ "start_time": "2024-07-31T12:57:57.045954Z"
+ }
+ },
+ "source": [],
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
- "display_name": "Python (ml)",
+ "display_name": "base",
"language": "python",
- "name": "ml"
+ "name": "base"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
- "version": 2
+ "version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.6"
+ "pygments_lexer": "ipython3",
+ "version": "3.12.4"
}
},
"nbformat": 4,
From bc32b507ea06a51390ddc3d15dc5bdbf19f10986 Mon Sep 17 00:00:00 2001
From: lakshith
Date: Wed, 31 Jul 2024 20:39:46 +0530
Subject: [PATCH 13/16] clear notebook outputs
---
labml_nn/transformers/LoRA/experiment.ipynb | 75 ++---------
labml_nn/transformers/LoRA/train.ipynb | 137 ++++----------------
2 files changed, 34 insertions(+), 178 deletions(-)
diff --git a/labml_nn/transformers/LoRA/experiment.ipynb b/labml_nn/transformers/LoRA/experiment.ipynb
index 7070991d..f0ae1c84 100644
--- a/labml_nn/transformers/LoRA/experiment.ipynb
+++ b/labml_nn/transformers/LoRA/experiment.ipynb
@@ -1,12 +1,7 @@
{
"cells": [
{
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-07-31T12:22:57.496965Z",
- "start_time": "2024-07-31T12:22:55.151730Z"
- }
- },
+ "metadata": {},
"cell_type": "code",
"source": [
"from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
@@ -14,15 +9,10 @@
],
"id": "cffa3ec341b4905a",
"outputs": [],
- "execution_count": 1
+ "execution_count": null
},
{
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-07-31T12:22:57.986397Z",
- "start_time": "2024-07-31T12:22:57.498305Z"
- }
- },
+ "metadata": {},
"cell_type": "code",
"source": [
"from transformers import AutoTokenizer\n",
@@ -31,17 +21,13 @@
],
"id": "c2b0b7e18394ea9e",
"outputs": [],
- "execution_count": 2
+ "execution_count": null
},
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
- "collapsed": true,
- "ExecuteTime": {
- "end_time": "2024-07-31T12:22:58.562136Z",
- "start_time": "2024-07-31T12:22:57.987296Z"
- }
+ "collapsed": true
},
"source": [
"model = GPTModel()\n",
@@ -54,32 +40,11 @@
"if unexpected_keys:\n",
" print(f\"Unexpected keys: {unexpected_keys}\")"
],
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/tmp/ipykernel_7130/2581223434.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
- " state_dict = torch.load('transformed.pth')\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Missing keys: ['token_embedding.lora_a', 'token_embedding.lora_b', 'position_embedding.lora_a', 'position_embedding.lora_b', 'blocks.0.attn.c_att.lora_a', 'blocks.0.attn.c_att.lora_b', 'blocks.0.attn.c_proj.lora_a', 'blocks.0.attn.c_proj.lora_b', 'blocks.0.ffn.c_fc.lora_a', 'blocks.0.ffn.c_fc.lora_b', 'blocks.0.ffn.c_proj.lora_a', 'blocks.0.ffn.c_proj.lora_b', 'blocks.1.attn.c_att.lora_a', 'blocks.1.attn.c_att.lora_b', 'blocks.1.attn.c_proj.lora_a', 'blocks.1.attn.c_proj.lora_b', 'blocks.1.ffn.c_fc.lora_a', 'blocks.1.ffn.c_fc.lora_b', 'blocks.1.ffn.c_proj.lora_a', 'blocks.1.ffn.c_proj.lora_b', 'blocks.2.attn.c_att.lora_a', 'blocks.2.attn.c_att.lora_b', 'blocks.2.attn.c_proj.lora_a', 'blocks.2.attn.c_proj.lora_b', 'blocks.2.ffn.c_fc.lora_a', 'blocks.2.ffn.c_fc.lora_b', 'blocks.2.ffn.c_proj.lora_a', 'blocks.2.ffn.c_proj.lora_b', 'blocks.3.attn.c_att.lora_a', 'blocks.3.attn.c_att.lora_b', 'blocks.3.attn.c_proj.lora_a', 'blocks.3.attn.c_proj.lora_b', 'blocks.3.ffn.c_fc.lora_a', 'blocks.3.ffn.c_fc.lora_b', 'blocks.3.ffn.c_proj.lora_a', 'blocks.3.ffn.c_proj.lora_b', 'blocks.4.attn.c_att.lora_a', 'blocks.4.attn.c_att.lora_b', 'blocks.4.attn.c_proj.lora_a', 'blocks.4.attn.c_proj.lora_b', 'blocks.4.ffn.c_fc.lora_a', 'blocks.4.ffn.c_fc.lora_b', 'blocks.4.ffn.c_proj.lora_a', 'blocks.4.ffn.c_proj.lora_b', 'blocks.5.attn.c_att.lora_a', 'blocks.5.attn.c_att.lora_b', 'blocks.5.attn.c_proj.lora_a', 'blocks.5.attn.c_proj.lora_b', 'blocks.5.ffn.c_fc.lora_a', 'blocks.5.ffn.c_fc.lora_b', 'blocks.5.ffn.c_proj.lora_a', 'blocks.5.ffn.c_proj.lora_b', 'blocks.6.attn.c_att.lora_a', 'blocks.6.attn.c_att.lora_b', 'blocks.6.attn.c_proj.lora_a', 'blocks.6.attn.c_proj.lora_b', 'blocks.6.ffn.c_fc.lora_a', 'blocks.6.ffn.c_fc.lora_b', 'blocks.6.ffn.c_proj.lora_a', 'blocks.6.ffn.c_proj.lora_b', 'blocks.7.attn.c_att.lora_a', 'blocks.7.attn.c_att.lora_b', 'blocks.7.attn.c_proj.lora_a', 'blocks.7.attn.c_proj.lora_b', 'blocks.7.ffn.c_fc.lora_a', 'blocks.7.ffn.c_fc.lora_b', 'blocks.7.ffn.c_proj.lora_a', 'blocks.7.ffn.c_proj.lora_b', 'blocks.8.attn.c_att.lora_a', 'blocks.8.attn.c_att.lora_b', 'blocks.8.attn.c_proj.lora_a', 'blocks.8.attn.c_proj.lora_b', 'blocks.8.ffn.c_fc.lora_a', 'blocks.8.ffn.c_fc.lora_b', 'blocks.8.ffn.c_proj.lora_a', 'blocks.8.ffn.c_proj.lora_b', 'blocks.9.attn.c_att.lora_a', 'blocks.9.attn.c_att.lora_b', 'blocks.9.attn.c_proj.lora_a', 'blocks.9.attn.c_proj.lora_b', 'blocks.9.ffn.c_fc.lora_a', 'blocks.9.ffn.c_fc.lora_b', 'blocks.9.ffn.c_proj.lora_a', 'blocks.9.ffn.c_proj.lora_b', 'blocks.10.attn.c_att.lora_a', 'blocks.10.attn.c_att.lora_b', 'blocks.10.attn.c_proj.lora_a', 'blocks.10.attn.c_proj.lora_b', 'blocks.10.ffn.c_fc.lora_a', 'blocks.10.ffn.c_fc.lora_b', 'blocks.10.ffn.c_proj.lora_a', 'blocks.10.ffn.c_proj.lora_b', 'blocks.11.attn.c_att.lora_a', 'blocks.11.attn.c_att.lora_b', 'blocks.11.attn.c_proj.lora_a', 'blocks.11.attn.c_proj.lora_b', 'blocks.11.ffn.c_fc.lora_a', 'blocks.11.ffn.c_fc.lora_b', 'blocks.11.ffn.c_proj.lora_a', 'blocks.11.ffn.c_proj.lora_b', 'lm_head.lora_a', 'lm_head.lora_b']\n"
- ]
- }
- ],
- "execution_count": 3
+ "outputs": [],
+ "execution_count": null
},
{
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-07-31T12:23:00.447976Z",
- "start_time": "2024-07-31T12:22:58.566527Z"
- }
- },
+ "metadata": {},
"cell_type": "code",
"source": [
"prompt = \"hello how are you\"\n",
@@ -96,32 +61,16 @@
" print(tokenizer.decode(id))"
],
"id": "f4f7826ec3729b66",
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- ",\n",
- " to\n",
- " you\n",
- " doing\n"
- ]
- }
- ],
- "execution_count": 4
+ "outputs": [],
+ "execution_count": null
},
{
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-07-31T12:23:00.452060Z",
- "start_time": "2024-07-31T12:23:00.448904Z"
- }
- },
+ "metadata": {},
"cell_type": "code",
"source": "",
"id": "c12776360008a974",
"outputs": [],
- "execution_count": 4
+ "execution_count": null
}
],
"metadata": {
diff --git a/labml_nn/transformers/LoRA/train.ipynb b/labml_nn/transformers/LoRA/train.ipynb
index cd70bfb3..b2e3038e 100644
--- a/labml_nn/transformers/LoRA/train.ipynb
+++ b/labml_nn/transformers/LoRA/train.ipynb
@@ -7,41 +7,27 @@
"collapsed": true,
"jupyter": {
"outputs_hidden": true
- },
- "ExecuteTime": {
- "end_time": "2024-07-31T12:57:37.296030Z",
- "start_time": "2024-07-31T12:57:37.292368Z"
}
},
"source": "# !wget https://raw.github/zusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
"outputs": [],
- "execution_count": 1
+ "execution_count": null
},
{
"cell_type": "code",
"id": "3b1e507015ba6b81",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-07-31T12:57:37.317651Z",
- "start_time": "2024-07-31T12:57:37.313808Z"
- }
- },
+ "metadata": {},
"source": [
"with open('input.txt', 'r', encoding='utf-8') as f:\n",
" text = f.read()"
],
"outputs": [],
- "execution_count": 2
+ "execution_count": null
},
{
"cell_type": "code",
"id": "ac8e51ae5bbfcae7",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-07-31T12:57:40.488939Z",
- "start_time": "2024-07-31T12:57:37.319486Z"
- }
- },
+ "metadata": {},
"source": [
"from transformers import AutoTokenizer\n",
"\n",
@@ -49,75 +35,47 @@
"\n",
"tokens = tokenizer.encode(text, add_special_tokens=False)"
],
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Token indices sequence length is longer than the specified maximum sequence length for this model (338025 > 1024). Running this sequence through the model will result in indexing errors\n"
- ]
- }
- ],
- "execution_count": 3
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
"id": "aeefcdf813e427e",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-07-31T12:57:40.495510Z",
- "start_time": "2024-07-31T12:57:40.490341Z"
- }
- },
+ "metadata": {},
"source": [
"context_length = 512\n",
"batch_size = 2"
],
"outputs": [],
- "execution_count": 4
+ "execution_count": null
},
{
"cell_type": "code",
"id": "a384b42274f008a2",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-07-31T12:57:40.522050Z",
- "start_time": "2024-07-31T12:57:40.496842Z"
- }
- },
+ "metadata": {},
"source": [
"num_batches = len(tokens) // (batch_size * context_length)\n",
"tokens = tokens[:num_batches * batch_size * context_length]"
],
"outputs": [],
- "execution_count": 5
+ "execution_count": null
},
{
"cell_type": "code",
"id": "5c4cc78ac1a02c1d",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-07-31T12:57:40.592272Z",
- "start_time": "2024-07-31T12:57:40.524063Z"
- }
- },
+ "metadata": {},
"source": [
"import torch\n",
"\n",
"input_ids = torch.tensor(tokens).view(-1, context_length)"
],
"outputs": [],
- "execution_count": 6
+ "execution_count": null
},
{
"cell_type": "code",
"id": "7037fd75e2161382",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-07-31T12:57:40.601199Z",
- "start_time": "2024-07-31T12:57:40.593250Z"
- }
- },
+ "metadata": {},
"source": [
"from torch.utils.data import DataLoader, TensorDataset\n",
"from torch.optim import Adam\n",
@@ -137,17 +95,12 @@
"test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
],
"outputs": [],
- "execution_count": 7
+ "execution_count": null
},
{
"cell_type": "code",
"id": "a98b7baa064b8494",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-07-31T12:57:41.577878Z",
- "start_time": "2024-07-31T12:57:40.602187Z"
- }
- },
+ "metadata": {},
"source": [
"from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
"\n",
@@ -157,15 +110,10 @@
"_ = model.load_state_dict(state_dict, strict=False)"
],
"outputs": [],
- "execution_count": 8
+ "execution_count": null
},
{
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-07-31T12:57:43.098187Z",
- "start_time": "2024-07-31T12:57:41.578713Z"
- }
- },
+ "metadata": {},
"cell_type": "code",
"source": [
"device = \"cuda\"\n",
@@ -173,17 +121,12 @@
],
"id": "2e0fa8b3082df716",
"outputs": [],
- "execution_count": 9
+ "execution_count": null
},
{
"cell_type": "code",
"id": "e2f5076894770740",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-07-31T12:57:57.044755Z",
- "start_time": "2024-07-31T12:57:43.099050Z"
- }
- },
+ "metadata": {},
"source": [
"from labml import tracker, experiment\n",
"\n",
@@ -236,49 +179,13 @@
"\n",
"print(\"Training complete.\")"
],
- "outputs": [
- {
- "data": {
- "text/plain": [
- ""
- ],
- "text/html": [
- "\n",
- "LoRA.GPT2 : 7a14822c4f3c11efad8354ef33f17c7c \n",
- "\t[dirty]: \"training loop\" \n",
- "Monitor experiment at http://localhost:5005/run/7a14822c4f3c11efad8354ef33f17c7c \n",
- "Still updating labml server, please wait for it to complete... "
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "ename": "KeyboardInterrupt",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
- "\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
- "Cell \u001B[0;32mIn[10], line 25\u001B[0m\n\u001B[1;32m 22\u001B[0m loss \u001B[38;5;241m=\u001B[39m criterion(shift_logits\u001B[38;5;241m.\u001B[39mreshape(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m, shift_logits\u001B[38;5;241m.\u001B[39msize(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m)), shift_labels\u001B[38;5;241m.\u001B[39mreshape(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m))\n\u001B[1;32m 24\u001B[0m optimizer\u001B[38;5;241m.\u001B[39mzero_grad()\n\u001B[0;32m---> 25\u001B[0m loss\u001B[38;5;241m.\u001B[39mbackward()\n\u001B[1;32m 26\u001B[0m optimizer\u001B[38;5;241m.\u001B[39mstep()\n\u001B[1;32m 28\u001B[0m tracker\u001B[38;5;241m.\u001B[39msave(step, {\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mloss\u001B[39m\u001B[38;5;124m'\u001B[39m: loss})\n",
- "File \u001B[0;32m~/miniconda3/lib/python3.12/site-packages/torch/_tensor.py:521\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m 511\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m 512\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m 513\u001B[0m Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m 514\u001B[0m (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 519\u001B[0m inputs\u001B[38;5;241m=\u001B[39minputs,\n\u001B[1;32m 520\u001B[0m )\n\u001B[0;32m--> 521\u001B[0m torch\u001B[38;5;241m.\u001B[39mautograd\u001B[38;5;241m.\u001B[39mbackward(\n\u001B[1;32m 522\u001B[0m \u001B[38;5;28mself\u001B[39m, gradient, retain_graph, create_graph, inputs\u001B[38;5;241m=\u001B[39minputs\n\u001B[1;32m 523\u001B[0m )\n",
- "File \u001B[0;32m~/miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:289\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m 284\u001B[0m retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m 286\u001B[0m \u001B[38;5;66;03m# The reason we repeat the same comment below is that\u001B[39;00m\n\u001B[1;32m 287\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m 288\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 289\u001B[0m _engine_run_backward(\n\u001B[1;32m 290\u001B[0m tensors,\n\u001B[1;32m 291\u001B[0m grad_tensors_,\n\u001B[1;32m 292\u001B[0m retain_graph,\n\u001B[1;32m 293\u001B[0m create_graph,\n\u001B[1;32m 294\u001B[0m inputs,\n\u001B[1;32m 295\u001B[0m allow_unreachable\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m,\n\u001B[1;32m 296\u001B[0m accumulate_grad\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m,\n\u001B[1;32m 297\u001B[0m )\n",
- "File \u001B[0;32m~/miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:768\u001B[0m, in \u001B[0;36m_engine_run_backward\u001B[0;34m(t_outputs, *args, **kwargs)\u001B[0m\n\u001B[1;32m 766\u001B[0m unregister_hooks \u001B[38;5;241m=\u001B[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001B[1;32m 767\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m--> 768\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m Variable\u001B[38;5;241m.\u001B[39m_execution_engine\u001B[38;5;241m.\u001B[39mrun_backward( \u001B[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001B[39;00m\n\u001B[1;32m 769\u001B[0m t_outputs, \u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs\n\u001B[1;32m 770\u001B[0m ) \u001B[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001B[39;00m\n\u001B[1;32m 771\u001B[0m \u001B[38;5;28;01mfinally\u001B[39;00m:\n\u001B[1;32m 772\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m attach_logging_hooks:\n",
- "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
- ]
- }
- ],
- "execution_count": 10
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
"id": "da2d4023002648dc",
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-07-31T12:57:57.046254Z",
- "start_time": "2024-07-31T12:57:57.045954Z"
- }
- },
+ "metadata": {},
"source": [],
"outputs": [],
"execution_count": null
From dc4762161d9aafdb73775b0989f64861c4fd2875 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Fri, 2 Aug 2024 15:32:02 +0530
Subject: [PATCH 14/16] Clean up LoRA
---
.../{transformers/LoRA => lora}/__init__.py | 0
.../LoRA/GPT2.py => lora/gpt2.py} | 2 +-
.../{transformers/LoRA => lora}/train.ipynb | 54 ++++++-----
labml_nn/lora/transform_hf_model.py | 46 +++++++++
labml_nn/{RWKV => rwkv}/__init__.py | 0
labml_nn/{RWKV => rwkv}/configs.py | 0
labml_nn/{RWKV => rwkv}/experiment.py | 6 +-
labml_nn/transformers/LoRA/experiment.ipynb | 97 -------------------
labml_nn/transformers/LoRA/load_hf.py | 44 ---------
9 files changed, 78 insertions(+), 171 deletions(-)
rename labml_nn/{transformers/LoRA => lora}/__init__.py (100%)
rename labml_nn/{transformers/LoRA/GPT2.py => lora/gpt2.py} (98%)
rename labml_nn/{transformers/LoRA => lora}/train.ipynb (93%)
create mode 100644 labml_nn/lora/transform_hf_model.py
rename labml_nn/{RWKV => rwkv}/__init__.py (100%)
rename labml_nn/{RWKV => rwkv}/configs.py (100%)
rename labml_nn/{RWKV => rwkv}/experiment.py (97%)
delete mode 100644 labml_nn/transformers/LoRA/experiment.ipynb
delete mode 100644 labml_nn/transformers/LoRA/load_hf.py
diff --git a/labml_nn/transformers/LoRA/__init__.py b/labml_nn/lora/__init__.py
similarity index 100%
rename from labml_nn/transformers/LoRA/__init__.py
rename to labml_nn/lora/__init__.py
diff --git a/labml_nn/transformers/LoRA/GPT2.py b/labml_nn/lora/gpt2.py
similarity index 98%
rename from labml_nn/transformers/LoRA/GPT2.py
rename to labml_nn/lora/gpt2.py
index a7a59342..a83a0276 100644
--- a/labml_nn/transformers/LoRA/GPT2.py
+++ b/labml_nn/lora/gpt2.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer
-from labml_nn.transformers.LoRA import Linear, Embedding
+from labml_nn.lora import Linear, Embedding
tokenizer = AutoTokenizer.from_pretrained("gpt2")
diff --git a/labml_nn/transformers/LoRA/train.ipynb b/labml_nn/lora/train.ipynb
similarity index 93%
rename from labml_nn/transformers/LoRA/train.ipynb
rename to labml_nn/lora/train.ipynb
index b2e3038e..68bbb7eb 100644
--- a/labml_nn/transformers/LoRA/train.ipynb
+++ b/labml_nn/lora/train.ipynb
@@ -1,5 +1,22 @@
{
"cells": [
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null,
+ "source": [
+ "import torch\n",
+ "from torch.optim import Adam\n",
+ "from torch.utils.data import DataLoader, TensorDataset\n",
+ "from torch.utils.data import random_split\n",
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "from labml import tracker, experiment\n",
+ "from labml_nn.lora.gpt2 import GPTModel"
+ ],
+ "id": "f072832ec9d346e1"
+ },
{
"cell_type": "code",
"id": "initial_id",
@@ -29,8 +46,6 @@
"id": "ac8e51ae5bbfcae7",
"metadata": {},
"source": [
- "from transformers import AutoTokenizer\n",
- "\n",
"tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
"\n",
"tokens = tokenizer.encode(text, add_special_tokens=False)"
@@ -64,11 +79,7 @@
"cell_type": "code",
"id": "5c4cc78ac1a02c1d",
"metadata": {},
- "source": [
- "import torch\n",
- "\n",
- "input_ids = torch.tensor(tokens).view(-1, context_length)"
- ],
+ "source": "input_ids = torch.tensor(tokens).view(-1, context_length)",
"outputs": [],
"execution_count": null
},
@@ -77,10 +88,6 @@
"id": "7037fd75e2161382",
"metadata": {},
"source": [
- "from torch.utils.data import DataLoader, TensorDataset\n",
- "from torch.optim import Adam\n",
- "from torch.utils.data import random_split\n",
- "\n",
"dataset = TensorDataset(input_ids)\n",
"\n",
"train_ratio = 0.8\n",
@@ -102,8 +109,6 @@
"id": "a98b7baa064b8494",
"metadata": {},
"source": [
- "from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
- "\n",
"model = GPTModel()\n",
"state_dict = torch.load('transformed.pth', weights_only=True)\n",
"\n",
@@ -128,8 +133,6 @@
"id": "e2f5076894770740",
"metadata": {},
"source": [
- "from labml import tracker, experiment\n",
- "\n",
"optimizer = Adam(model.parameters(), lr=5e-5)\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"\n",
@@ -143,39 +146,38 @@
" inputs = batch[0]\n",
" inputs = inputs.to(device)\n",
" labels = inputs.clone()\n",
- " \n",
+ "\n",
" outputs = model(inputs)\n",
- " \n",
+ "\n",
" shift_logits = outputs[..., :-1, :]\n",
" shift_labels = labels[..., 1:]\n",
- " \n",
+ "\n",
" loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
- " \n",
+ "\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
- " \n",
+ "\n",
" tracker.save(step, {'loss': loss})\n",
" step += 1\n",
" print(f'Epoch: {epoch + 1}, Loss: {loss.item()}')\n",
- " \n",
+ "\n",
" test_loss = 0\n",
" for batch in test_dataloader:\n",
" inputs = batch[0]\n",
" inputs = inputs.to(device)\n",
" labels = inputs.clone()\n",
- " \n",
+ "\n",
" outputs = model(inputs)\n",
- " \n",
+ "\n",
" shift_logits = outputs[..., :-1, :]\n",
" shift_labels = labels[..., 1:]\n",
- " \n",
+ "\n",
" loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
- " \n",
+ "\n",
" test_loss += loss.item()\n",
" test_loss /= len(test_dataloader)\n",
" tracker.save(step, {'test_loss': test_loss})\n",
- " \n",
"\n",
"print(\"Training complete.\")"
],
diff --git a/labml_nn/lora/transform_hf_model.py b/labml_nn/lora/transform_hf_model.py
new file mode 100644
index 00000000..df53bbf2
--- /dev/null
+++ b/labml_nn/lora/transform_hf_model.py
@@ -0,0 +1,46 @@
+import torch
+from transformers import AutoModelForCausalLM
+
+
+def transform_hf_model():
+ model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+ state_dict = model.state_dict()
+
+ mapping = {
+ 'transformer.wte.weight': 'token_embedding.weight',
+ 'transformer.wpe.weight': 'position_embedding.weight',
+ 'transformer.ln_f.weight': 'final_norm.weight',
+ 'transformer.ln_f.bias': 'final_norm.bias',
+ 'lm_head.weight': 'lm_head.weight'
+ }
+
+ for i in range(12):
+ mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
+ mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
+ mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
+ mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
+ mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
+ mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
+ mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
+ mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
+ mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
+ mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
+ mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
+ mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'
+
+ new_state_dict = {}
+ for old_key, new_key in mapping.items():
+ if old_key in state_dict:
+ new_state_dict[new_key] = state_dict[old_key]
+
+    # transpose weight matrices of Conv1D layers so we can use linear layers instead
+ convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
+ [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
+ [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
+ [f'blocks.{i}.attn.c_proj.weight' for i in range(12)])
+
+ for layer in convo_layers:
+ new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
+
+ torch.save(new_state_dict, 'transformed.pth')
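For context on the transpose above (an illustrative check, not part of the patch): the Hugging Face GPT-2 checkpoint stores these projections as Conv1D layers whose weights have shape (in_features, out_features), while nn.Linear expects (out_features, in_features), so the matrices must be transposed once when converting.

    import torch
    import torch.nn.functional as F

    w_conv1d = torch.randn(768, 3 * 768)   # Conv1D-style weight, e.g. c_attn: (in, out)
    w_linear = w_conv1d.T.contiguous()     # Linear-style weight: (out, in)
    x = torch.randn(2, 768)
    assert torch.allclose(x @ w_conv1d, F.linear(x, w_linear), atol=1e-5)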
diff --git a/labml_nn/RWKV/__init__.py b/labml_nn/rwkv/__init__.py
similarity index 100%
rename from labml_nn/RWKV/__init__.py
rename to labml_nn/rwkv/__init__.py
diff --git a/labml_nn/RWKV/configs.py b/labml_nn/rwkv/configs.py
similarity index 100%
rename from labml_nn/RWKV/configs.py
rename to labml_nn/rwkv/configs.py
diff --git a/labml_nn/RWKV/experiment.py b/labml_nn/rwkv/experiment.py
similarity index 97%
rename from labml_nn/RWKV/experiment.py
rename to labml_nn/rwkv/experiment.py
index 1f99d66d..983db2c0 100644
--- a/labml_nn/RWKV/experiment.py
+++ b/labml_nn/rwkv/experiment.py
@@ -3,10 +3,10 @@ import math
import torch
import torch.nn as nn
-from labml_nn.RWKV.configs import RWKVConfigs
+from labml_nn.rwkv.configs import RWKVConfigs
-from labml_nn.RWKV import RWKV
-from labml_nn.RWKV import TimeMixing
+from labml_nn.rwkv import RWKV
+from labml_nn.rwkv import TimeMixing
from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
diff --git a/labml_nn/transformers/LoRA/experiment.ipynb b/labml_nn/transformers/LoRA/experiment.ipynb
deleted file mode 100644
index f0ae1c84..00000000
--- a/labml_nn/transformers/LoRA/experiment.ipynb
+++ /dev/null
@@ -1,97 +0,0 @@
-{
- "cells": [
- {
- "metadata": {},
- "cell_type": "code",
- "source": [
- "from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
- "import torch"
- ],
- "id": "cffa3ec341b4905a",
- "outputs": [],
- "execution_count": null
- },
- {
- "metadata": {},
- "cell_type": "code",
- "source": [
- "from transformers import AutoTokenizer\n",
- "\n",
- "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
- ],
- "id": "c2b0b7e18394ea9e",
- "outputs": [],
- "execution_count": null
- },
- {
- "cell_type": "code",
- "id": "initial_id",
- "metadata": {
- "collapsed": true
- },
- "source": [
- "model = GPTModel()\n",
- "\n",
- "state_dict = torch.load('transformed.pth')\n",
- "\n",
- "missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)\n",
- "if missing_keys:\n",
- " print(f\"Missing keys: {missing_keys}\")\n",
- "if unexpected_keys:\n",
- " print(f\"Unexpected keys: {unexpected_keys}\")"
- ],
- "outputs": [],
- "execution_count": null
- },
- {
- "metadata": {},
- "cell_type": "code",
- "source": [
- "prompt = \"hello how are you\"\n",
- "tokenized = tokenizer(prompt, return_tensors=\"pt\")\n",
- "tokenized['input_ids'] = tokenized['input_ids'].to('cuda')\n",
- "model = model.to('cuda')\n",
- "\n",
- "with torch.no_grad():\n",
- " model.eval()\n",
- " res = model(tokenized['input_ids'])\n",
- "\n",
- "output_ids = torch.argmax(res, dim=-1)\n",
- "for id in output_ids[0]:\n",
- " print(tokenizer.decode(id))"
- ],
- "id": "f4f7826ec3729b66",
- "outputs": [],
- "execution_count": null
- },
- {
- "metadata": {},
- "cell_type": "code",
- "source": "",
- "id": "c12776360008a974",
- "outputs": [],
- "execution_count": null
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 2
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.6"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/labml_nn/transformers/LoRA/load_hf.py b/labml_nn/transformers/LoRA/load_hf.py
deleted file mode 100644
index 0e8ff6be..00000000
--- a/labml_nn/transformers/LoRA/load_hf.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import torch
-from transformers import AutoModelForCausalLM
-
-model = AutoModelForCausalLM.from_pretrained("gpt2")
-
-state_dict = model.state_dict()
-
-mapping = {
- 'transformer.wte.weight': 'token_embedding.weight',
- 'transformer.wpe.weight': 'position_embedding.weight',
- 'transformer.ln_f.weight': 'final_norm.weight',
- 'transformer.ln_f.bias': 'final_norm.bias',
- 'lm_head.weight': 'lm_head.weight'
-}
-
-for i in range(12):
- mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
- mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
- mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
- mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
- mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
- mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
- mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
- mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
- mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
- mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
- mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
- mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'
-
-new_state_dict = {}
-for old_key, new_key in mapping.items():
- if old_key in state_dict:
- new_state_dict[new_key] = state_dict[old_key]
-
-# transpose weight matrices of convo 1d layers to use linear layers instead
-convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
- [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
- [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
- [f'blocks.{i}.attn.c_proj.weight' for i in range(12)])
-
-for layer in convo_layers:
- new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
-
-torch.save(new_state_dict, 'transformed.pth')
From eb9337e949961c0b0352f763ab52aef2abac73de Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Fri, 2 Aug 2024 15:33:45 +0530
Subject: [PATCH 15/16] Clean up LoRA
---
labml_nn/lora/__init__.py | 26 ++++++++++----------------
1 file changed, 10 insertions(+), 16 deletions(-)
diff --git a/labml_nn/lora/__init__.py b/labml_nn/lora/__init__.py
index 302a4bf9..9124ebc9 100644
--- a/labml_nn/lora/__init__.py
+++ b/labml_nn/lora/__init__.py
@@ -1,18 +1,17 @@
+"""
+# LoRA
+"""
+
import torch
import torch.nn as nn
class Linear(nn.Module):
- def __init__(
- self,
- in_features: int,
- out_features: int,
- bias: bool,
- r: int,
- alpha: int = None):
+ def __init__(self, in_features: int, out_features: int, bias: bool,
+ r: int, alpha: int = None):
+ super().__init__()
if alpha is None:
alpha = r
- super().__init__()
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
self.weight.requires_grad = False
@@ -39,16 +38,11 @@ class Linear(nn.Module):
class Embedding(nn.Module):
- def __init__(
- self,
- num_embeddings: int,
- embedding_dim: int,
- r: int,
- alpha: int = None,
- ):
+ def __init__(self, num_embeddings: int, embedding_dim: int,
+ r: int, alpha: int = None):
+ super().__init__()
if alpha is None:
alpha = r
- super().__init__()
self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
self.weight.requires_grad = False
From d4af40b595ebd7e1eb7fd872c02cc0911cb23bb4 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Sat, 3 Aug 2024 16:59:15 +0530
Subject: [PATCH 16/16] LoRA notes
---
docs/RWKV/configs.html | 8 +-
docs/RWKV/experiment.html | 14 +-
docs/RWKV/index.html | 8 +-
docs/gan/wasserstein/index.html | 10 +-
docs/lora/gpt2.html | 378 +++++++++++++++++++++
docs/lora/index.html | 534 ++++++++++++++++++++++++++++++
docs/lora/transform_hf_model.html | 186 +++++++++++
docs/sitemap.xml | 35 +-
labml_nn/lora/__init__.py | 91 ++++-
9 files changed, 1236 insertions(+), 28 deletions(-)
create mode 100644 docs/lora/gpt2.html
create mode 100644 docs/lora/index.html
create mode 100644 docs/lora/transform_hf_model.html
diff --git a/docs/RWKV/configs.html b/docs/RWKV/configs.html
index 3780bb86..463c144a 100644
--- a/docs/RWKV/configs.html
+++ b/docs/RWKV/configs.html
@@ -12,7 +12,7 @@
-
+
@@ -23,7 +23,7 @@
configs.py
-
+
@@ -47,7 +47,7 @@
diff --git a/docs/RWKV/experiment.html b/docs/RWKV/experiment.html
index 71698823..281bcac1 100644
--- a/docs/RWKV/experiment.html
+++ b/docs/RWKV/experiment.html
@@ -12,7 +12,7 @@
-
+
@@ -23,7 +23,7 @@
experiment.py
-
+
@@ -47,7 +47,7 @@
@@ -78,10 +78,10 @@
3
4 import torch
5 import torch.nn as nn
-6 from labml_nn.RWKV.configs import RWKVConfigs
+6 from labml_nn.rwkv.configs import RWKVConfigs
7
-8 from labml_nn.RWKV import RWKV
-9 from labml_nn.RWKV import TimeMixing
+8 from labml_nn.rwkv import RWKV
+9 from labml_nn.rwkv import TimeMixing
10 from labml import experiment
11 from labml.configs import option
12 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
diff --git a/docs/RWKV/index.html b/docs/RWKV/index.html
index cb73300b..5462e088 100644
--- a/docs/RWKV/index.html
+++ b/docs/RWKV/index.html
@@ -12,7 +12,7 @@
-
+
@@ -23,7 +23,7 @@
Receptance Weighted Key Value (RWKV)
-
+
@@ -47,7 +47,7 @@
diff --git a/docs/gan/wasserstein/index.html b/docs/gan/wasserstein/index.html
index 72bb41c4..b3a28135 100644
--- a/docs/gan/wasserstein/index.html
+++ b/docs/gan/wasserstein/index.html
@@ -74,17 +74,17 @@
Wasserstein GAN (WGAN)
This is an implementation of Wasserstein GAN.
The original GAN loss is based on Jensen-Shannon (JS) divergence between the real distribution P_r and generated distribution P_g. The Wasserstein GAN is based on Earth Mover distance between these distributions.
-W(P_r, P_g) = inf_{γ ∈ Π(P_r, P_g)} E_{(x, y) ∼ γ} [‖x − y‖]
+W(P_r, P_g) = inf_{γ ∈ Π(P_r, P_g)} E_{(x, y) ∼ γ} [‖x − y‖]
Π(P_r, P_g) is the set of all joint distributions, whose marginal probabilities are γ(x, y).
E_{(x, y) ∼ γ} [‖x − y‖] is the earth mover distance for a given joint distribution (x and y are probabilities).
-So W(P_r, P_g) is equal to the least earth mover distance for any joint distribution between the real distribution P_r and generated distribution P_g.
+So W(P_r, P_g) is equal to the least earth mover distance for any joint distribution between the real distribution P_r and generated distribution P_g.
The paper shows that Jensen-Shannon (JS) divergence and other measures for the difference between two probability distributions are not smooth. And therefore if we are doing gradient descent on one of the probability distributions (parameterized) it will not converge.
-Based on Kantorovich-Rubinstein duality, W(P_r, P_g) = sup_{‖f‖_L ≤ 1} E_{x ∼ P_r}[f(x)] − E_{x ∼ P_g}[f(x)]
+Based on Kantorovich-Rubinstein duality, W(P_r, P_g) = sup_{‖f‖_L ≤ 1} E_{x ∼ P_r}[f(x)] − E_{x ∼ P_g}[f(x)]
where ‖f‖_L ≤ 1 are all 1-Lipschitz functions.
That is, it is equal to the greatest difference E_{x ∼ P_r}[f(x)] − E_{x ∼ P_g}[f(x)] among all 1-Lipschitz functions.
-For K-Lipschitz functions, W(P_r, P_g) = sup_{‖f‖_L ≤ K} E_{x ∼ P_r}[(1/K) f(x)] − E_{x ∼ P_g}[(1/K) f(x)]
+For K-Lipschitz functions, W(P_r, P_g) = sup_{‖f‖_L ≤ K} E_{x ∼ P_r}[(1/K) f(x)] − E_{x ∼ P_g}[(1/K) f(x)]
If all K-Lipschitz functions can be represented as f_w where f is parameterized by w ∈ W,
-K · W(P_r, P_g) = max_{w ∈ W} E_{x ∼ P_r}[f_w(x)] − E_{x ∼ P_g}[f_w(x)]
+K · W(P_r, P_g) = max_{w ∈ W} E_{x ∼ P_r}[f_w(x)] − E_{x ∼ P_g}[f_w(x)]
If P_g is represented by a generator g_θ(z) and z is from a known distribution z ∼ p(z),
K · W(P_r, P_θ) = max_{w ∈ W} E_{x ∼ P_r}[f_w(x)] − E_{z ∼ p(z)}[f_w(g_θ(z))]
Now to converge g_θ with P_r we can do gradient descent on θ to minimize the above formula.
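A minimal sketch of the critic objective this implies (illustrative naming only, not the repository's training code): the critic f_w is trained to maximize E_{x ∼ P_r}[f_w(x)] − E_{z ∼ p(z)}[f_w(g_θ(z))], i.e. to minimize the negative of that difference:

    import torch

    def critic_loss(f_w, g_theta, x_real, z):
        # f_w: critic network, g_theta: generator (held fixed during the critic step)
        with torch.no_grad():
            x_fake = g_theta(z)
        return -(f_w(x_real).mean() - f_w(x_fake).mean())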
diff --git a/docs/lora/gpt2.html b/docs/lora/gpt2.html
new file mode 100644
index 00000000..bed238dc
--- /dev/null
+++ b/docs/lora/gpt2.html
@@ -0,0 +1,378 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ gpt2.py
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
1 import torch
+2 import torch.nn as nn
+3 from transformers import AutoTokenizer
+4 from labml_nn.lora import Linear , Embedding
+5
+6 tokenizer = AutoTokenizer . from_pretrained ( "gpt2" )
+7
+8 config = {
+9 "layer_norm_epsilon" : 1e-05 ,
+10 "n_embd" : 768 ,
+11 "n_head" : 12 ,
+12 "n_layer" : 12 ,
+13 "n_positions" : 1024 ,
+14 "vocab_size" : 50257 ,
+15 "device" : "cuda"
+16 }
+
+
+
+
+
+
+
20 def __init__ ( self , dim ):
+21 super () . __init__ ()
+22 self . c_fc = Linear ( config [ 'n_embd' ], dim , r = 32 , bias = True )
+23 self . c_proj = Linear ( dim , config [ 'n_embd' ], r = 32 , bias = True )
+24 self . act = nn . functional . gelu
+
+
+
+
+
+
26 def forward ( self , hidden_states ):
+27 hidden_states = self . c_fc ( hidden_states )
+28 hidden_states = self . act ( hidden_states )
+29 hidden_states = self . c_proj ( hidden_states )
+30 return hidden_states
+
+
+
+
+
+
33 class MultiHeadAttention ( nn . Module ):
+
+
+
+
+
+
34 def __init__ ( self ):
+35 super () . __init__ ()
+36 self . embed_dim = config [ 'n_embd' ]
+37 self . num_heads = config [ 'n_head' ]
+38 self . head_dim = self . embed_dim // self . num_heads
+39 self . split_size = self . embed_dim
+40
+41 self . c_att = Linear ( config [ 'n_embd' ], config [ 'n_embd' ] * 3 , r = 32 , bias = True )
+42 self . c_proj = Linear ( config [ 'n_embd' ], config [ 'n_embd' ], r = 32 , bias = True )
+
+
+
+
+
+
Splits hidden_size dim into attn_head_size and num_heads
+
+
+
+
44 def _split_heads ( self , tensor , num_heads , attn_head_size ):
+
+
+
+
+
+
48 new_shape = tensor . size ()[: - 1 ] + ( num_heads , attn_head_size )
+49 tensor = tensor . view ( new_shape )
+50 return tensor . permute ( 0 , 2 , 1 , 3 ) # (batch, head, seq_length, head_features)
+
+
+
+
+
+
52 def forward ( self , hidden_states ):
+53 batch_size , seq_length , _ = hidden_states . size ()
+54
+55 query , key , value = self . c_att ( hidden_states ) . split ( self . split_size , dim = 2 )
+56
+57 query = self . _split_heads ( query , self . num_heads , self . head_dim )
+58 key = self . _split_heads ( key , self . num_heads , self . head_dim )
+59 value = self . _split_heads ( value , self . num_heads , self . head_dim )
+60
+61 attn_output = torch . nn . functional . scaled_dot_product_attention (
+62 query ,
+63 key ,
+64 value ,
+65 attn_mask = None ,
+66 dropout_p = 0.0 ,
+67 is_causal = True , # for the triangular mask
+68 )
+69
+70 attn_output = attn_output . transpose ( 1 , 2 ) . contiguous ()
+71 attn_output = attn_output . view ( batch_size , seq_length , self . embed_dim )
+72
+73 attn_output = self . c_proj ( attn_output )
+74
+75 return attn_output
+
+
+
+
+
+
78 class Block ( nn . Module ):
+
+
+
+
+
+
79 def __init__ ( self ):
+80 super () . __init__ ()
+81 self . pre_norm = nn . LayerNorm ( config [ 'n_embd' ], eps = config [ 'layer_norm_epsilon' ])
+82 self . attn = MultiHeadAttention ()
+83 self . post_norm = nn . LayerNorm ( config [ 'n_embd' ], eps = config [ 'layer_norm_epsilon' ])
+84 self . ffn = FFN ( config [ 'n_embd' ] * 4 )
+
+
+
+
+
+
86 def forward ( self , hidden_states ):
+87 residual = hidden_states
+88 hidden_states = self . pre_norm ( hidden_states )
+89
+90 attn_output = self . attn ( hidden_states )
+91
+92 hidden_states = attn_output + residual
+93 residual = hidden_states
+94 hidden_states = self . post_norm ( hidden_states )
+95 feed_forward_output = self . ffn ( hidden_states )
+96 hidden_states = feed_forward_output + residual
+97
+98 return hidden_states
+
+
+
+
+
+
101 class GPTModel ( nn . Module ):
+
+
+
+
+
+
102 def __init__ ( self ):
+103 super () . __init__ ()
+104
+105 self . token_embedding = Embedding ( config [ 'vocab_size' ], config [ 'n_embd' ], r = 32 )
+106 self . position_embedding = Embedding ( config [ 'n_positions' ], config [ 'n_embd' ], r = 32 )
+107
+108 self . blocks = nn . ModuleList ([ Block () for _ in range ( config [ 'n_layer' ])])
+109
+110 self . final_norm = nn . LayerNorm ( config [ 'n_embd' ], eps = config [ 'layer_norm_epsilon' ])
+111
+112 self . lm_head = Linear ( config [ 'n_embd' ], config [ 'vocab_size' ], r = 32 , bias = False )
+
+
+
+
+
+
114 def forward ( self , input_ids ):
+115 batch_size , input_shape = input_ids . size ()
+116
+117 token_embeddings = self . token_embedding ( input_ids ) # B T C
+118 position_ids = torch . arange ( input_shape , device = config [ 'device' ]) # T C
+119 position_embeddings = self . position_embedding ( position_ids ) # B T C
+120
+121 hidden_states = token_embeddings + position_embeddings
+122
+123 for block in self . blocks :
+124 hidden_states = block ( hidden_states )
+125
+126 hidden_states = self . final_norm ( hidden_states )
+127
+128 logits = self . lm_head ( hidden_states )
+129
+130 return logits
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/lora/index.html b/docs/lora/index.html
new file mode 100644
index 00000000..46d25217
--- /dev/null
+++ b/docs/lora/index.html
@@ -0,0 +1,534 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Low-Rank Adaptation (LoRA)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Low-Rank Adaptation (LoRA)
+
This is an implementation of Low-Rank Adaptation (LoRA) in PyTorch.
+
Low-Rank Adaptation (LoRA) freezes the pre-trained model weights and injects trainable rank decomposition matrices into each layer of the transformer. This makes it possible to efficiently fine-tune large language models by reducing the number of trainable parameters by a large factor.
+
Here's the training code for training a GPT2 model with LoRA on the Tiny Shakespeare dataset.
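To make the reduction concrete, a rough illustrative count using the values from gpt2.py (n_embd = 768, r = 32): a single frozen 768 × 768 projection holds 589,824 parameters, while its LoRA update adds only 768·32 + 32·768 = 49,152 trainable parameters, about 12× fewer for that layer, and the saving grows for the wider projections.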
+
+
+
+
24 import torch
+25 import torch.nn as nn
+
+
+
+
+
+
LoRA Linear Layer
+
The LoRA linear layer adds a low-rank decomposition to the pre-trained weight matrix W_0 ∈ R^{d×k} of the linear layer:
+
W_0 + ΔW = W_0 + BA
+
where B ∈ R^{d×r}, A ∈ R^{r×k}, and the rank r ≪ min(d, k).
+
All parameters are frozen except A and B.
+
ΔW is initialized to zero at the beginning of training.
+
They multiply ΔW x by α/r, where α is a hyper-parameter. Once α is tuned it can be kept the same when varying r.
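As a minimal usage sketch (assuming the Linear layer defined below; in practice weight and bias would first be filled from a pre-trained checkpoint), only lora_a and lora_b receive gradients:

    import torch
    from labml_nn.lora import Linear

    layer = Linear(768, 768, bias=True, r=32)   # W_0 and b_0 are created frozen
    x = torch.randn(1, 10, 768)
    y = layer(x)                                # x W_0^T + b_0 + (alpha / r) x A B
    trainable = [n for n, p in layer.named_parameters() if p.requires_grad]
    print(trainable)                            # ['lora_a', 'lora_b']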
+
+
+
+
28 class Linear ( nn . Module ):
+
+
+
+
+
+
in_features
+ is the number of input features of the linear layer
+out_features
+ is the number of output features of the linear layer
+bias
+ is a flag indicating if there is a bias parameter
+r
+ is the rank of the decomposition r
+alpha
+ is the scaling factor α
+
+
+
+
49 def __init__ ( self , in_features : int , out_features : int , bias : bool ,
+50 r : int , alpha : int = None ):
+
+
+
+
+
+
+
Set α = r if α is not provided, i.e. make the scaling factor α/r = 1.
+
+
+
+
61 if alpha is None :
+62 alpha = r
+
+
+
+
+
+
The pre-trained weight W_0
+
+
+
+
65 self . weight = nn . Parameter ( torch . empty (( out_features , in_features )))
+
+
+
+
+
+
67 self . weight . requires_grad = False
+68
+69 if bias :
+
+
+
+
+
+
Bias parameter b_0 (also frozen)
+
+
+
+
71 self . bias = nn . Parameter ( torch . empty ( out_features ))
+72 self . bias . requires_grad = False
+73 else :
+
+
+
+
+
+
No bias parameter
+
+
+
+
+
+
+
+
scaling factor α/r
+
+
+
+
78 self . scaling = alpha / r
+
+
+
+
+
+
80 self . lora_a = nn . Parameter ( torch . empty (( in_features , r )))
+
+
+
+
+
+
Matrix B ∈ R^{d×r}; we keep A and B transposed
+
+
+
+
82 self . lora_b = nn . Parameter ( torch . empty (( r , out_features )))
+83
+84 with torch . no_grad ():
+
+
+
+
+
+
Initialize A similar to a weight matrix in a normal linear layer
+
+
+
+
86 nn . init . kaiming_uniform_ ( self . lora_a , a = 5 ** 0.5 )
+
+
+
+
+
+
Initialize B to 0 so that ΔW = BA is 0 at initialization
+
+
+
+
88 nn . init . zeros_ ( self . lora_b )
+
+
+
+
+
+
90 def forward ( self , x : torch . Tensor ):
+
+
+
+
+
+
92 result = nn . functional . linear ( x , self . weight , bias = self . bias )
+
+
+
+
+
+
Add ( α / r ) Δ W x = ( α / r ) B A x
+
+
+
+
95 result += ( x @ self . lora_a @ self . lora_b ) * self . scaling
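A brief usage sketch (hypothetical, not part of the repository): copy a pre-trained layer's weights into the frozen W 0 of the Linear module above; only the low-rank factors then remain trainable.

```python
import torch

pretrained = torch.nn.Linear(768, 768)            # stands in for a pre-trained layer
lora_linear = Linear(768, 768, bias=True, r=8)    # the LoRA Linear defined above; alpha defaults to r

with torch.no_grad():
    lora_linear.weight.copy_(pretrained.weight)   # frozen W_0
    lora_linear.bias.copy_(pretrained.bias)       # frozen b_0

x = torch.randn(2, 16, 768)
y = lora_linear(x)                                # W_0 x + b_0 + (alpha / r) B A x; equals W_0 x + b_0 at initialization

# Only the low-rank factors require gradients
print([n for n, p in lora_linear.named_parameters() if p.requires_grad])  # ['lora_a', 'lora_b']
```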
+
+
+
+
+
+
+
LoRA Embedding Layer
+
Similar to the LoRA linear layer, this adds a low-rank decomposition to the pre-trained embedding weights matrix (W 0 ∈ R d × k ).
+
W 0 + Δ W = W 0 + B A
+
+
+
+
101 class Embedding ( nn . Module ):
+
+
+
+
+
+
num_embeddings
+ is the number of embeddings
+embedding_dim
is the number of embedding dimensions
+r
+ is the rank of the decomposition r
+alpha
+ is the scaling factor α
+
+
+
+
111 def __init__ ( self , num_embeddings : int , embedding_dim : int ,
+112 r : int , alpha : int = None ):
+
+
+
+
+
+
+
Set α = r if it is not provided, i.e. make the scaling factor α / r = 1 .
+
+
+
+
123 if alpha is None :
+124 alpha = r
+
+
+
+
+
+
The pre-trained embedding weights W 0 (frozen)
+
+
+
+
127 self . weight = nn . Parameter ( torch . empty (( num_embeddings , embedding_dim )))
+128 self . weight . requires_grad = False
+
+
+
+
+
+
scaling factor α / r
+
+
+
+
131 self . scaling = alpha / r
+
+
+
+
+
+
133 self . lora_a = nn . Parameter ( torch . empty (( num_embeddings , r )))
+
+
+
+
+
+
135 self . lora_b = nn . Parameter ( torch . empty (( r , embedding_dim )))
+136
+137 with torch . no_grad ():
+
+
+
+
+
+
Initialize A with a normal distribution
+
+
+
+
139 nn . init . normal_ ( self . lora_a )
+
+
+
+
+
+
Initialize B to 0 so that Δ W = B A is 0 at initialization
+
+
+
+
141 nn . init . zeros_ ( self . lora_b )
+
+
+
+
+
+
143 def forward ( self , x : torch . Tensor ):
+
+
+
+
+
+
Compute the embeddings W 0 onehot ( x )
+
+
+
+
145 result = nn . functional . embedding ( x , self . weight )
+
+
+
+
+
+
148 result += ( nn . functional . embedding ( x , self . lora_a ) @ self . lora_b ) * self . scaling
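Similarly, a hypothetical sketch of wrapping a pre-trained token-embedding matrix with the Embedding module above:

```python
import torch

pretrained = torch.nn.Embedding(50257, 768)       # stands in for GPT-2's token embedding
lora_embedding = Embedding(50257, 768, r=8)       # the LoRA Embedding defined above

with torch.no_grad():
    lora_embedding.weight.copy_(pretrained.weight)   # frozen W_0

ids = torch.randint(0, 50257, (2, 16))
e = lora_embedding(ids)   # (2, 16, 768); equals the frozen embeddings at initialization since B = 0
```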
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/lora/transform_hf_model.html b/docs/lora/transform_hf_model.html
new file mode 100644
index 00000000..a9d34c3a
--- /dev/null
+++ b/docs/lora/transform_hf_model.html
@@ -0,0 +1,186 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ transform_hf_model.py
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
1 import torch
+2 from transformers import AutoModelForCausalLM
+
+
+
+
+
+
5 def transform_hf_model ():
+6 model = AutoModelForCausalLM . from_pretrained ( "gpt2" )
+7
+8 state_dict = model . state_dict ()
+9
+10 mapping = {
+11 'transformer.wte.weight' : 'token_embedding.weight' ,
+12 'transformer.wpe.weight' : 'position_embedding.weight' ,
+13 'transformer.ln_f.weight' : 'final_norm.weight' ,
+14 'transformer.ln_f.bias' : 'final_norm.bias' ,
+15 'lm_head.weight' : 'lm_head.weight'
+16 }
+17
+18 for i in range ( 12 ):
+19 mapping [ f 'transformer.h. { i } .ln_1.weight' ] = f 'blocks. { i } .pre_norm.weight'
+20 mapping [ f 'transformer.h. { i } .ln_1.bias' ] = f 'blocks. { i } .pre_norm.bias'
+21 mapping [ f 'transformer.h. { i } .attn.c_attn.weight' ] = f 'blocks. { i } .attn.c_att.weight'
+22 mapping [ f 'transformer.h. { i } .attn.c_attn.bias' ] = f 'blocks. { i } .attn.c_att.bias'
+23 mapping [ f 'transformer.h. { i } .attn.c_proj.weight' ] = f 'blocks. { i } .attn.c_proj.weight'
+24 mapping [ f 'transformer.h. { i } .attn.c_proj.bias' ] = f 'blocks. { i } .attn.c_proj.bias'
+25 mapping [ f 'transformer.h. { i } .ln_2.weight' ] = f 'blocks. { i } .post_norm.weight'
+26 mapping [ f 'transformer.h. { i } .ln_2.bias' ] = f 'blocks. { i } .post_norm.bias'
+27 mapping [ f 'transformer.h. { i } .mlp.c_fc.weight' ] = f 'blocks. { i } .ffn.c_fc.weight'
+28 mapping [ f 'transformer.h. { i } .mlp.c_fc.bias' ] = f 'blocks. { i } .ffn.c_fc.bias'
+29 mapping [ f 'transformer.h. { i } .mlp.c_proj.weight' ] = f 'blocks. { i } .ffn.c_proj.weight'
+30 mapping [ f 'transformer.h. { i } .mlp.c_proj.bias' ] = f 'blocks. { i } .ffn.c_proj.bias'
+31
+32 new_state_dict = {}
+33 for old_key , new_key in mapping . items ():
+34 if old_key in state_dict :
+35 new_state_dict [ new_key ] = state_dict [ old_key ]
+
+
+
+
+
+
Transpose the weight matrices of the Conv1D layers so that linear layers can be used instead
+
+
+
+
38 convo_layers = ([ f 'blocks. { i } .ffn.c_fc.weight' for i in range ( 12 )] +
+39 [ f 'blocks. { i } .ffn.c_proj.weight' for i in range ( 12 )] +
+40 [ f 'blocks. { i } .attn.c_att.weight' for i in range ( 12 )] +
+41 [ f 'blocks. { i } .attn.c_proj.weight' for i in range ( 12 )])
+42
+43 for layer in convo_layers :
+44 new_state_dict [ layer ] = torch . transpose ( new_state_dict [ layer ], 0 , 1 )
+45
+46 torch . save ( new_state_dict , 'transformed.pth' )
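A hypothetical follow-up (the class name GPTModel below is an assumption and not taken from the repository): the saved state dict can then be loaded into the custom GPT-2 module with strict=False , so the freshly initialized LoRA parameters are left untouched.

```python
import torch

model = GPTModel()                        # the custom GPT-2 module (assumed name)
state_dict = torch.load('transformed.pth')

missing, unexpected = model.load_state_dict(state_dict, strict=False)
# `missing` should list only the LoRA factors (lora_a / lora_b), which keep their fresh initialization
print(missing, unexpected)
```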
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index 7b46859e..d7cc9aff 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -8,7 +8,7 @@
https://nn.labml.ai/gan/wasserstein/index.html
- 2023-10-24T16:30:00+00:00
+ 2024-07-15T16:30:00+00:00
1.00
@@ -504,22 +504,22 @@
- https://nn.labml.ai/RWKV/configs.html
- 2024-03-17T16:30:00+00:00
+ https://nn.labml.ai/rwkv/configs.html
+ 2024-08-02T16:30:00+00:00
1.00
- https://nn.labml.ai/RWKV/index.html
- 2024-03-17T16:30:00+00:00
+ https://nn.labml.ai/rwkv/index.html
+ 2024-08-02T16:30:00+00:00
1.00
- https://nn.labml.ai/RWKV/experiment.html
- 2024-03-17T16:30:00+00:00
+ https://nn.labml.ai/rwkv/experiment.html
+ 2024-08-02T16:30:00+00:00
1.00
@@ -1294,6 +1294,27 @@
+
+ https://nn.labml.ai/lora/gpt2.html
+ 2024-08-02T16:30:00+00:00
+ 1.00
+
+
+
+
+ https://nn.labml.ai/lora/index.html
+ 2024-08-02T16:30:00+00:00
+ 1.00
+
+
+
+
+ https://nn.labml.ai/lora/transform_hf_model.html
+ 2024-08-02T16:30:00+00:00
+ 1.00
+
+
+
https://nn.labml.ai/graphs/gat/index.html
2023-10-24T16:30:00+00:00
diff --git a/labml_nn/lora/__init__.py b/labml_nn/lora/__init__.py
index 9124ebc9..f5fc197d 100644
--- a/labml_nn/lora/__init__.py
+++ b/labml_nn/lora/__init__.py
@@ -1,5 +1,24 @@
"""
-# LoRA
+---
+title: Low-Rank Adaptation (LoRA)
+summary: >
+ Annotated implementation of LoRA from the paper
+ LoRA: Low-Rank Adaptation of Large Language Models
+---
+
+# Low-Rank Adaptation (LoRA)
+
+This is an implementation of
+[Low-Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685)
+in [PyTorch](https://pytorch.org).
+
+Low-Rank Adaptation (LoRA) freezes pre-trained model weights and injects
+ trainable rank decomposition matrices into each layer of the transformer.
+ This makes it possible to efficiently fine-tune large language models by
+ reducing trainable parameters by a large factor.
+
+Here's [the training code](experiment.html) for training a GPT2 model with LoRA
+ on the Tiny Shakespeare dataset.
"""
import torch
@@ -7,56 +26,126 @@ import torch.nn as nn
class Linear(nn.Module):
+ """
+ ## LoRA Linear Layer
+
+ LoRA linear layer adds a low-rank decomposition to the pre-trained
+ weight matrix ($W_0 \in \mathbb{R}^{d \times k}$)
+ of the linear layer.
+
+ $$W_0 + \Delta W = W_0 + BA$$
+
+ where $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$,
+ and the rank $r \ll \min(d, k)$.
+
+ All parameters are frozen except $A$ and $B$.
+
+ $\Delta W$ is initialized to be zero at the beginning of the training.
+
+ They multiply $\Delta W x$ by $\frac{\alpha}{r}$, where $\alpha$ is a hyper-parameter.
+ Once $\alpha$ is tuned it can be kept the same when varying $r$.
+ """
+
def __init__(self, in_features: int, out_features: int, bias: bool,
r: int, alpha: int = None):
+ """
+ :param in_features: is the number of input features of the linear layer
+ :param out_features: is the number of output features of the linear layer
+ :param bias: is a flag indicating if there is a bias parameter
+ :param r: is the rank of the decomposition $r$
+ :param alpha: is the scaling factor $\alpha$
+ """
super().__init__()
+
+ # Set $\alpha = r$ if it is not provided, i.e. make the scaling factor $\frac{\alpha}{r} = 1$.
if alpha is None:
alpha = r
+
+ # The pre-trained weight $W_0$
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
+ # Freeze it
self.weight.requires_grad = False
if bias:
+ # Bias parameter $b_0$ (also frozen)
self.bias = nn.Parameter(torch.empty(out_features))
self.bias.requires_grad = False
else:
+ # No bias parameter
self.bias = None
+ # scaling factor $\frac{\alpha}{r}$
self.scaling = alpha / r
+ # Matrix $A \in \mathbb{R}^{r \times k}$
self.lora_a = nn.Parameter(torch.empty((in_features, r)))
+ # Matrix $B \in \mathbb{R}^{d \times r}$, we keep $A$ and $B$ transposed
self.lora_b = nn.Parameter(torch.empty((r, out_features)))
with torch.no_grad():
+ # Initialize $A$ similar to a weight matrix in a normal linear layer
nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
+ # Initialize $B$ to $0$ so that $\Delta W = BA$ is $0$ at initialization
nn.init.zeros_(self.lora_b)
def forward(self, x: torch.Tensor):
+ # Compute $W_0 x + b_0$
result = nn.functional.linear(x, self.weight, bias=self.bias)
+ # Add $\frac{\alpha}{r} \Delta W x = \frac{\alpha}{r} BAx$
result += (x @ self.lora_a @ self.lora_b) * self.scaling
+ #
return result
class Embedding(nn.Module):
+ """
+ ## LoRA Embedding Layer
+
+ Similar to the LoRA linear layer, this adds a low-rank decomposition to the pre-trained
+ embedding weights matrix ($W_0 \in \mathbb{R}^{d \times k}$).
+
+ $$W_0 + \Delta W = W_0 + BA$$
+ """
+
def __init__(self, num_embeddings: int, embedding_dim: int,
r: int, alpha: int = None):
+ """
+
+ :param num_embeddings: is the number of embeddings
+ :param embedding_dim: is the number of embedding dimensions
+ :param r: is the rank of the decomposition $r$
+ :param alpha: is the scaling factor $\alpha$
+ """
super().__init__()
+
+ # Set $\alpha = r$ if it is not provided, i.e. make the scaling factor $\frac{\alpha}{r} = 1$.
if alpha is None:
alpha = r
+ # The pre-trained embedding weights $W_0$ (frozen)
self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
self.weight.requires_grad = False
+ # scaling factor $\frac{\alpha}{r}$
self.scaling = alpha / r
+ # Matrix $A \in \mathbb{R}^{r \times k}$
self.lora_a = nn.Parameter(torch.empty((num_embeddings, r)))
+ # Matrix $B \in \mathbb{R}^{d \times r}$
self.lora_b = nn.Parameter(torch.empty((r, embedding_dim)))
with torch.no_grad():
+ # Initialize $A$ with a normal distribution
nn.init.normal_(self.lora_a)
+ # Initialize $B$ to $0$ so that $\Delta W = BA$ is $0$ at initialization
nn.init.zeros_(self.lora_b)
def forward(self, x: torch.Tensor):
+ # Compute the embeddings $W_0 \text{onehot}(x)$
result = nn.functional.embedding(x, self.weight)
+
+ # Add $\frac{\alpha}{r} \Delta W \text{onehot}(x) = \frac{\alpha}{r} BA \text{onehot}(x)$
result += (nn.functional.embedding(x, self.lora_a) @ self.lora_b) * self.scaling
+ #
return result