mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 08:41:23 +08:00
rename layers
@@ -6,16 +6,14 @@ from labml_nn.lora import Linear, Embedding
 class FFN(nn.Module):
     def __init__(self, dim: int, n_embed: int, r: int):
         super().__init__()
-        # lin1
-        self.c_fc = Linear(n_embed, dim, r=r, bias=True)
-        # lin2
-        self.c_proj = Linear(dim, n_embed, r=r, bias=True)
+        self.linear_in = Linear(n_embed, dim, r=r, bias=True)
+        self.linear_out = Linear(dim, n_embed, r=r, bias=True)
         self.act = nn.functional.gelu

     def forward(self, hidden_states):
-        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.linear_in(hidden_states)
         hidden_states = self.act(hidden_states)
-        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.linear_out(hidden_states)
         return hidden_states
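For readers skimming the diff: the Linear imported from labml_nn.lora is a LoRA layer, not nn.Linear, so the renamed linear_in and linear_out each carry a frozen pretrained weight plus a trainable rank-r update. Below is a minimal sketch of the standard LoRA formulation such a layer follows; the class name, signature, and initialization here are illustrative assumptions, not the repo's actual implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRALinearSketch(nn.Module):
    """Illustrative LoRA linear: y = x W^T + b + (alpha / r) * x A^T B^T."""

    def __init__(self, in_features: int, out_features: int, r: int,
                 alpha: int = 1, bias: bool = True):
        super().__init__()
        # Pretrained weight and bias stay frozen during fine-tuning.
        self.weight = nn.Parameter(torch.randn(out_features, in_features),
                                   requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features),
                                 requires_grad=False) if bias else None
        # Trainable low-rank factors; B starts at zero so the initial
        # output equals the pretrained layer's output.
        self.lora_a = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        frozen = F.linear(x, self.weight, self.bias)
        update = F.linear(F.linear(x, self.lora_a), self.lora_b)
        return frozen + self.scaling * update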
@@ -27,10 +25,10 @@ class MultiHeadAttention(nn.Module):
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim

-        # qkv
-        self.c_att = Linear(n_embed, n_embed * 3, r=r, bias=True)
+        # query key value
+        self.qkv_projection = Linear(n_embed, n_embed * 3, r=r, bias=True)
         # out
-        self.c_proj = Linear(n_embed, n_embed, r=r, bias=True)
+        self.output_projection = Linear(n_embed, n_embed, r=r, bias=True)

     def _split_heads(self, tensor, num_heads, attn_head_size):
         """
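The body of _split_heads falls outside this diff's context window. Assuming it follows the usual GPT-2 convention, it reshapes (batch, seq, embed) into (batch, heads, seq, head_dim) along these lines; this is a sketch under that assumption, not the file's actual code.

import torch

def _split_heads(tensor: torch.Tensor, num_heads: int, attn_head_size: int) -> torch.Tensor:
    # (batch, seq, embed) -> (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
    batch_size, seq_length, _ = tensor.size()
    tensor = tensor.view(batch_size, seq_length, num_heads, attn_head_size)
    return tensor.permute(0, 2, 1, 3)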
@@ -43,7 +41,7 @@ class MultiHeadAttention(nn.Module):
     def forward(self, hidden_states):
         batch_size, seq_length, _ = hidden_states.size()

-        query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)
+        query, key, value = self.qkv_projection(hidden_states).split(self.split_size, dim=2)

         query = self._split_heads(query, self.num_heads, self.head_dim)
         key = self._split_heads(key, self.num_heads, self.head_dim)
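qkv_projection produces one fused tensor of width 3 * n_embed, and Tensor.split(self.split_size, dim=2) cuts it into three n_embed-wide chunks, since split_size is set to embed_dim above. A quick shape check with made-up sizes:

import torch

n_embed = 8
qkv = torch.randn(2, 5, 3 * n_embed)           # (batch, seq, 3 * n_embed)
query, key, value = qkv.split(n_embed, dim=2)  # three (2, 5, 8) tensors
print(query.shape, key.shape, value.shape)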
@@ -61,7 +59,7 @@ class MultiHeadAttention(nn.Module):
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)

-        attn_output = self.c_proj(attn_output)
+        attn_output = self.output_projection(attn_output)

         return attn_output
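The attention computation between the last two hunks is also outside the diff context. A standard causal scaled dot-product attention over the split heads, which is what GPT-2-style blocks use, looks roughly like this; a sketch under that assumption, not the repo's exact code.

import math
import torch

def causal_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # query, key, value: (batch, heads, seq, head_dim)
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    seq_length = query.size(-2)
    # Causal mask: position i may only attend to positions <= i.
    mask = torch.tril(torch.ones(seq_length, seq_length,
                                 dtype=torch.bool, device=query.device))
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ value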