import torch
import torch.nn as nn

from labml_nn.lora import Linear, Embedding


class FFN(nn.Module):
    def __init__(self, dim: int, n_embed: int, r: int):
        super().__init__()
        # lin1: expand `n_embed -> dim`
        self.c_fc = Linear(n_embed, dim, r=r, bias=True)
        # lin2: project back `dim -> n_embed`
        self.c_proj = Linear(dim, n_embed, r=r, bias=True)
        self.act = nn.functional.gelu

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return hidden_states


class MultiHeadAttention(nn.Module):
    def __init__(self, n_embed: int, r: int):
        super().__init__()
        self.embed_dim = n_embed
        # Number of heads (set equal to `n_embed` here, so each head has size 1)
        self.num_heads = n_embed
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        # qkv projection
        self.c_att = Linear(n_embed, n_embed * 3, r=r, bias=True)
        # output projection
        self.c_proj = Linear(n_embed, n_embed, r=r, bias=True)
    def _split_heads(self, tensor, num_heads, attn_head_size):
        """Splits hidden_size dim into attn_head_size and num_heads"""
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, hidden_states):
        batch_size, seq_length, _ = hidden_states.size()

        query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=True,  # for the triangular mask
        )

        # Merge heads back: (batch, head, seq_length, head_features) -> (batch, seq_length, embed_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)

        attn_output = self.c_proj(attn_output)

        return attn_output
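

# A small, self-contained sketch (not part of the original model) illustrating the
# head split/merge used by `MultiHeadAttention` above: [batch, seq, n_embed] is
# reshaped to [batch, heads, seq, head_dim], run through causal scaled dot-product
# attention, and merged back. The head count of 4 is purely illustrative.
def _attention_shape_demo(batch: int = 2, seq: int = 5, n_embed: int = 8, n_heads: int = 4):
    head_dim = n_embed // n_heads
    q = torch.randn(batch, seq, n_embed)
    k = torch.randn(batch, seq, n_embed)
    v = torch.randn(batch, seq, n_embed)

    def split(t):
        # [batch, seq, n_embed] -> [batch, heads, seq, head_dim]
        return t.view(batch, seq, n_heads, head_dim).permute(0, 2, 1, 3)

    out = torch.nn.functional.scaled_dot_product_attention(split(q), split(k), split(v),
                                                           is_causal=True)
    # Merge heads back: [batch, heads, seq, head_dim] -> [batch, seq, n_embed]
    out = out.transpose(1, 2).contiguous().view(batch, seq, n_embed)
    assert out.shape == (batch, seq, n_embed)
    return out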


class Block(nn.Module):
    def __init__(self, n_embed: int, layer_norm_epsilon: float, r: int):
        super().__init__()
        self.pre_norm = nn.LayerNorm(n_embed, eps=layer_norm_epsilon)
        self.attn = MultiHeadAttention(n_embed, r)
        self.post_norm = nn.LayerNorm(n_embed, eps=layer_norm_epsilon)
        self.ffn = FFN(n_embed * 4, n_embed, r)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.pre_norm(hidden_states)

        attn_output = self.attn(hidden_states)

        hidden_states = attn_output + residual
        residual = hidden_states
        hidden_states = self.post_norm(hidden_states)
        feed_forward_output = self.ffn(hidden_states)
        hidden_states = feed_forward_output + residual

        return hidden_states


class GPTModel(nn.Module):
    def __init__(self, layer_norm_epsilon: float, n_embd: int, n_layer: int, n_positions: int,
                 vocab_size: int, r: int):
        super().__init__()

        self.token_embedding = Embedding(vocab_size, n_embd, r=r)
        self.position_embedding = Embedding(n_positions, n_embd, r=r)

        self.blocks = nn.ModuleList([Block(n_embd, layer_norm_epsilon, r=r)
                                     for _ in range(n_layer)])

        self.final_norm = nn.LayerNorm(n_embd, eps=layer_norm_epsilon)

        self.lm_head = Linear(n_embd, vocab_size, r=r, bias=False)

    def forward(self, input_ids: torch.Tensor):
        # `input_ids` has shape `[batch_size, seq_len]`
        batch_size, seq_len = input_ids.shape

        # Get token embeddings
        token_embeddings = self.token_embedding(input_ids)
        # Get position ids
        position_ids = torch.arange(seq_len, device=input_ids.device)[None, :]
        # Get position embeddings
        position_embeddings = self.position_embedding(position_ids)

        # Add position embeddings
        x = token_embeddings + position_embeddings

        # Run through transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final normalization
        x = self.final_norm(x)
        # Get logits from projection layer
        return self.lm_head(x)
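

# Minimal usage sketch (illustrative; not part of the original file). It assumes
# `labml_nn` is installed so the LoRA `Linear`/`Embedding` imports above resolve.
# The hyperparameters are small example values; GPT-2 small would use n_embd=768,
# n_layer=12, n_positions=1024, vocab_size=50257.
if __name__ == "__main__":
    model = GPTModel(layer_norm_epsilon=1e-5, n_embd=64, n_layer=2,
                     n_positions=128, vocab_size=1000, r=4)
    input_ids = torch.randint(0, 1000, (2, 16))  # [batch_size, seq_len]
    logits = model(input_ids)
    print(logits.shape)  # expected: torch.Size([2, 16, 1000])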