diff --git a/labml_nn/lora/experiment.py b/labml_nn/lora/experiment.py
index b64da063..163ce0c2 100644
--- a/labml_nn/lora/experiment.py
+++ b/labml_nn/lora/experiment.py
@@ -76,16 +76,16 @@ class Trainer(BaseConfigs):
         for i in range(12):
             mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
             mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
-            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
-            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
-            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
-            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
+            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.qkv_projection.weight'
+            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.qkv_projection.bias'
+            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.output_projection.weight'
+            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.output_projection.bias'
             mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
             mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
-            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
-            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
-            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
-            mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'
+            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.linear_in.weight'
+            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.linear_in.bias'
+            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.linear_out.weight'
+            mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.linear_out.bias'
 
         # Move the parameters based on mapping
         new_state_dict = {}
@@ -94,10 +94,10 @@ class Trainer(BaseConfigs):
                 new_state_dict[new_key] = state_dict[old_key]
 
         # GPT-2 hugging face uses 1D Convolution layers. We need to transpose those weights since we use linear layers
-        convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
-                        [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
-                        [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
-                        [f'blocks.{i}.attn.c_proj.weight' for i in range(12)])
+        convo_layers = ([f'blocks.{i}.ffn.linear_in.weight' for i in range(12)] +
+                        [f'blocks.{i}.ffn.linear_out.weight' for i in range(12)] +
+                        [f'blocks.{i}.attn.qkv_projection.weight' for i in range(12)] +
+                        [f'blocks.{i}.attn.output_projection.weight' for i in range(12)])
 
         for layer in convo_layers:
             new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
diff --git a/labml_nn/lora/gpt2.py b/labml_nn/lora/gpt2.py
index 284213ab..467fd757 100644
--- a/labml_nn/lora/gpt2.py
+++ b/labml_nn/lora/gpt2.py
@@ -6,16 +6,14 @@ from labml_nn.lora import Linear, Embedding
 class FFN(nn.Module):
     def __init__(self, dim: int, n_embed: int, r: int):
         super().__init__()
-        # lin1
-        self.c_fc = Linear(n_embed, dim, r=r, bias=True)
-        # lin2
-        self.c_proj = Linear(dim, n_embed, r=r, bias=True)
+        self.linear_in = Linear(n_embed, dim, r=r, bias=True)
+        self.linear_out = Linear(dim, n_embed, r=r, bias=True)
         self.act = nn.functional.gelu
 
     def forward(self, hidden_states):
-        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.linear_in(hidden_states)
         hidden_states = self.act(hidden_states)
-        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.linear_out(hidden_states)
         return hidden_states
 
 
@@ -27,10 +25,10 @@ class MultiHeadAttention(nn.Module):
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim
 
-        # qkv
-        self.c_att = Linear(n_embed, n_embed * 3, r=r, bias=True)
+        # query key value
+        self.qkv_projection = Linear(n_embed, n_embed * 3, r=r, bias=True)
         # out
-        self.c_proj = Linear(n_embed, n_embed, r=r, bias=True)
+        self.output_projection = Linear(n_embed, n_embed, r=r, bias=True)
 
     def _split_heads(self, tensor, num_heads, attn_head_size):
         """
@@ -43,7 +41,7 @@ class MultiHeadAttention(nn.Module):
     def forward(self, hidden_states):
         batch_size, seq_length, _ = hidden_states.size()
 
-        query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)
+        query, key, value = self.qkv_projection(hidden_states).split(self.split_size, dim=2)
 
         query = self._split_heads(query, self.num_heads, self.head_dim)
         key = self._split_heads(key, self.num_heads, self.head_dim)
@@ -61,7 +59,7 @@ class MultiHeadAttention(nn.Module):
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)
 
-        attn_output = self.c_proj(attn_output)
+        attn_output = self.output_projection(attn_output)
 
         return attn_output
 
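Note on the transposition in experiment.py: Hugging Face's GPT-2 implements c_attn, c_proj and c_fc as Conv1D modules, whose weight is stored as (in_features, out_features) and applied as x @ W + b, whereas nn.Linear stores (out_features, in_features) and computes x @ W.T + b. That is why the loader runs torch.transpose(new_state_dict[layer], 0, 1) on exactly the layers renamed above. The standalone sketch below is not part of this diff; it uses a plain nn.Linear instead of the LoRA Linear and assumes GPT-2 small's FFN sizes, purely to check the equivalence.

import torch
import torch.nn as nn

# GPT-2 small sizes, assumed here only for illustration
n_embed, dim = 768, 3072

x = torch.randn(2, 5, n_embed)              # (batch, seq, n_embed)

# A Conv1D-style parameter as it appears in the Hugging Face checkpoint:
# weight is (in_features, out_features) and is applied as x @ W + b
conv1d_weight = torch.randn(n_embed, dim)
conv1d_bias = torch.randn(dim)
conv1d_out = x @ conv1d_weight + conv1d_bias

# Copy the same parameters into a linear layer, transposing as the loader does
linear = nn.Linear(n_embed, dim, bias=True)
with torch.no_grad():
    linear.weight.copy_(torch.transpose(conv1d_weight, 0, 1))   # now (out, in)
    linear.bias.copy_(conv1d_bias)

# Both paths produce the same activations (up to float error)
assert torch.allclose(conv1d_out, linear(x), atol=1e-4)

The same reasoning covers qkv_projection and output_projection: their checkpoint weights arrive as (n_embed, 3 * n_embed) and (n_embed, n_embed) and are usable by the renamed Linear modules only after the transpose.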