import torch
import torch.nn as nn

from labml_nn.lora import Linear, Embedding
class FFN(nn.Module):
    """
    Two-layer feed-forward network with LoRA-adapted projections.

    ``c_fc`` maps ``n_embed -> dim``, GELU is applied, and ``c_proj`` maps
    ``dim -> n_embed``.
    """

    def __init__(self, dim: int, n_embed: int, r: int):
        """
        :param dim: width of the hidden (expanded) layer
        :param n_embed: feature size of the input and of the output
        :param r: LoRA rank passed to both projections
        """
        super().__init__()
        # first projection ("lin1"): n_embed -> dim
        self.c_fc = Linear(n_embed, dim, r=r, bias=True)
        # second projection ("lin2"): dim -> n_embed
        self.c_proj = Linear(dim, n_embed, r=r, bias=True)
        # torch's default (exact, erf-based) GELU.
        # NOTE(review): Hugging Face GPT-2 checkpoints use the tanh-approximated
        # GELU ("gelu_new"); confirm the small numerical difference is acceptable.
        self.act = nn.functional.gelu

    def forward(self, hidden_states):
        """Return ``c_proj(act(c_fc(hidden_states)))``."""
        return self.c_proj(self.act(self.c_fc(hidden_states)))
class MultiHeadAttention(nn.Module):
    """
    Causal multi-head self-attention with LoRA-adapted projections.

    ``c_att`` produces queries, keys and values in a single fused projection
    (``n_embed -> 3 * n_embed``). The heads attend causally through
    :func:`torch.nn.functional.scaled_dot_product_attention`, and ``c_proj``
    maps the concatenated head outputs back to ``n_embed`` features.
    """

    def __init__(self, n_embed: int, r: int, n_heads: int = 12):
        """
        :param n_embed: embedding size; must be divisible by ``n_heads``
        :param r: LoRA rank for both projections
        :param n_heads: number of attention heads. Defaults to 12, the head
            count of GPT-2 small (768 / 12 = 64 features per head).
        :raises ValueError: if ``n_heads`` is not positive or does not divide
            ``n_embed``
        """
        super().__init__()
        # num_heads must be a real head count. Setting it to n_embed would make
        # every head 1-dimensional and the attention scale 1/sqrt(1). The
        # projection shapes below do not depend on the head count, so
        # checkpoints load the same either way.
        if n_heads <= 0 or n_embed % n_heads != 0:
            raise ValueError(
                f"n_embed ({n_embed}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.embed_dim = n_embed
        self.num_heads = n_heads
        self.head_dim = n_embed // n_heads
        # q, k and v each occupy n_embed features of the fused projection output
        self.split_size = n_embed

        # fused query/key/value projection ("qkv")
        self.c_att = Linear(n_embed, n_embed * 3, r=r, bias=True)
        # output projection ("out")
        self.c_proj = Linear(n_embed, n_embed, r=r, bias=True)

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Split the last (hidden_size) dimension into ``num_heads`` heads of
        ``attn_head_size`` features each.

        ``(batch, seq, num_heads * attn_head_size)`` ->
        ``(batch, num_heads, seq, attn_head_size)``
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, hidden_states):
        """
        :param hidden_states: tensor of shape ``(batch, seq, n_embed)``
        :return: attention output of shape ``(batch, seq, n_embed)``
        """
        batch_size, seq_length, _ = hidden_states.size()

        query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # is_causal=True applies the lower-triangular mask: each position
        # attends only to itself and earlier positions. SDPA scales the scores
        # by 1/sqrt(head_dim) by default.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=True,
        )

        # (batch, head, seq, head_dim) -> (batch, seq, head, head_dim) -> (batch, seq, n_embed)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)

        return self.c_proj(attn_output)
class Block(nn.Module):
    """
    One pre-norm transformer layer: causal self-attention followed by a
    feed-forward network, each wrapped in a residual connection.
    """

    def __init__(self, n_embed: int, layer_norm_epsilon: float, r: int):
        """
        :param n_embed: embedding size
        :param layer_norm_epsilon: ``eps`` for both layer norms
        :param r: LoRA rank forwarded to the attention and feed-forward layers
        """
        super().__init__()
        # normalizes the input of the attention sub-layer
        self.pre_norm = nn.LayerNorm(n_embed, eps=layer_norm_epsilon)
        self.attn = MultiHeadAttention(n_embed, r)
        # normalizes the input of the feed-forward sub-layer
        self.post_norm = nn.LayerNorm(n_embed, eps=layer_norm_epsilon)
        # the feed-forward hidden layer is 4x wider than the embedding
        self.ffn = FFN(n_embed * 4, n_embed, r)

    def forward(self, hidden_states):
        """Return ``h + ffn(post_norm(h))`` where ``h = x + attn(pre_norm(x))``."""
        after_attn = self.attn(self.pre_norm(hidden_states)) + hidden_states
        return self.ffn(self.post_norm(after_attn)) + after_attn
class GPTModel(nn.Module):
    """
    GPT-style decoder-only language model with LoRA-adapted embeddings and
    linear layers.

    Token embeddings and learned position embeddings are summed, passed
    through ``n_layer`` :class:`Block` layers and a final layer norm, and then
    projected to vocabulary logits by ``lm_head``.
    """

    def __init__(self, layer_norm_epsilon: float, n_embd: int, n_layer: int, n_positions: int,
                 vocab_size: int, r: int):
        """
        :param layer_norm_epsilon: ``eps`` for every layer norm
        :param n_embd: embedding size
        :param n_layer: number of transformer blocks
        :param n_positions: number of rows in the position-embedding table,
            i.e. the longest sequence ``forward`` accepts
        :param vocab_size: number of tokens
        :param r: LoRA rank used by every adapted layer
        """
        super().__init__()
        # plain int (not a parameter), kept so forward() can reject
        # sequences longer than the position table
        self.n_positions = n_positions

        self.token_embedding = Embedding(vocab_size, n_embd, r=r)
        self.position_embedding = Embedding(n_positions, n_embd, r=r)

        self.blocks = nn.ModuleList([Block(n_embd, layer_norm_epsilon, r=r)
                                     for _ in range(n_layer)])

        self.final_norm = nn.LayerNorm(n_embd, eps=layer_norm_epsilon)

        # output projection to vocabulary logits, without bias
        self.lm_head = Linear(n_embd, vocab_size, r=r, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        :param input_ids: token ids of shape ``[batch_size, seq_len]``
        :return: logits from ``lm_head``, one ``vocab_size`` vector per position
        :raises ValueError: if ``input_ids`` is not 2-D, or if ``seq_len``
            exceeds ``n_positions``
        """
        if input_ids.dim() != 2:
            raise ValueError(
                f"input_ids must have shape [batch_size, seq_len], got {tuple(input_ids.shape)}"
            )
        seq_len = input_ids.shape[1]
        # Out-of-range position ids would otherwise fail inside the embedding
        # lookup, typically as an opaque index error on CPU or a device-side
        # assert on CUDA, instead of with a clear message.
        if seq_len > self.n_positions:
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum of {self.n_positions} positions"
            )

        # Get token embeddings
        token_embeddings = self.token_embedding(input_ids)
        # Position ids 0..seq_len-1, with a leading singleton dim so the
        # position embeddings broadcast over the batch when added
        position_ids = torch.arange(seq_len, device=input_ids.device)[None, :]
        position_embeddings = self.position_embedding(position_ids)

        x = token_embeddings + position_embeddings

        # Run through the transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final normalization, then logits from the projection layer
        x = self.final_norm(x)
        return self.lm_head(x)