Here's a Colab notebook for fine-tuning GPT-2 with LoRA on the Tiny Shakespeare dataset.
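For context, LoRA keeps the pre-trained GPT-2 weights frozen and learns a low-rank update for selected weight matrices, so the effective weight becomes W + (alpha / r) * B A with A of shape (r, d_in) and B of shape (d_out, r). The snippet below is a minimal, self-contained sketch of that idea for illustration only; the class name LoRALinear and its details are made up here and are not the labml_nn.lora.gpt2 implementation used in this notebook.

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    # Illustrative only: a frozen linear weight plus a trainable low-rank update.
    def __init__(self, in_features: int, out_features: int, r: int = 32, alpha: int = 32):
        super().__init__()
        # Frozen pre-trained weight (no gradients)
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02, requires_grad=False)
        # Trainable low-rank factors: A is (r, in_features), B is (out_features, r)
        self.lora_a = nn.Parameter(torch.randn(r, in_features) * 0.02)
        self.lora_b = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = x W^T + scaling * ((x A^T) B^T); B starts at zero, so training
        # begins from the pre-trained behaviour.
        return x @ self.weight.T + self.scaling * ((x @ self.lora_a.T) @ self.lora_b.T)

Only the low-rank factors receive gradients, which is why the rank (lora_r in the configs below) is the main knob that controls how many parameters are actually trained.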
import torch
from labml import lab, monit, tracker
from labml.configs import BaseConfigs, option
from labml.utils.download import download_file
from labml_helpers.device import DeviceConfigs
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from labml_nn.lora.gpt2 import GPTModel

The default configs can and will be overridden when we start the experiment.
class Trainer(BaseConfigs):
    device: torch.device = DeviceConfigs()

GPT-2 configs
    layer_norm_epsilon: float = 1e-05
    n_embed: int = 768
    n_layer: int = 12
    n_positions: int = 1024
    vocab_size: int = 50257

Training configs
    epochs: int = 10
    batch_size: int = 32
    learning_rate: float = 1e-4
    context_len: int = 512

LoRA rank

    lora_r: int = 32

Dataset
    text: TensorDataset = "tiny_shakespeare"
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model: GPTModel
    optimizer: torch.optim.Adam
    criterion = torch.nn.CrossEntropyLoss()
    data_loader: DataLoader

    def _load_pretrained_weights(self):

Load the Hugging Face model and get the parameters
        hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
        state_dict = hf_model.state_dict()

Transformer embedding and prediction layer parameter mapping (hf: ours)
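        # Note: HF GPT-2 ties lm_head.weight to transformer.wte.weight, so the
        # prediction head reuses the token embedding matrix.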
        mapping = {
            'transformer.wte.weight': 'token_embedding.weight',
            'transformer.wpe.weight': 'position_embedding.weight',
            'transformer.ln_f.weight': 'final_norm.weight',
            'transformer.ln_f.bias': 'final_norm.bias',
            'lm_head.weight': 'lm_head.weight'
        }

Mapping (hf: ours) of decoder layers
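        # GPT-2 small has 12 decoder blocks (n_layer above); each one maps onto blocks.{i} in our model.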
        for i in range(12):
            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
            mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'

Move the parameters based on the mapping
        new_state_dict = {}
        for old_key, new_key in mapping.items():
            if old_key in state_dict:
                new_state_dict[new_key] = state_dict[old_key]

The Hugging Face GPT-2 implementation uses 1D convolution (Conv1D) layers, so we need to transpose those weights since we use linear layers
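        # HF's Conv1D stores its weight as (in_features, out_features); nn.Linear expects
        # (out_features, in_features). For example, c_attn in GPT-2 small has a (768, 2304)
        # Conv1D weight, while the equivalent linear layer needs a (2304, 768) weight.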
        convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
                        [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.c_proj.weight' for i in range(12)])

        for layer in convo_layers:
            new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

Load our model
        self.model.load_state_dict(new_state_dict, strict=False)  # state dict does not have lora weights

    def initialize(self):

Initialize the model
        self.model = GPTModel(
            layer_norm_epsilon=self.layer_norm_epsilon,
            n_embd=self.n_embed,
            n_layer=self.n_layer,
            n_positions=self.n_positions,
            vocab_size=self.vocab_size,
            r=self.lora_r,
        )
        self.model.to(self.device)

Load pre-trained model weights
        self._load_pretrained_weights()

Initialize the optimizer
        self.optimizer = Adam(self.model.parameters(), lr=self.learning_rate)

Initialize the data loader
        self.data_loader = DataLoader(self.text, batch_size=self.batch_size, shuffle=True)

    def run(self):
        for _ in monit.loop(self.epochs):
            for i, batch in monit.enum('Train', self.data_loader):
                inputs = batch[0]
                inputs = inputs.to(self.device)
                labels = inputs.clone()

                outputs = self.model(inputs)

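                # The logits at position t predict token t + 1, so drop the last logit
                # and the first label to align predictions with next-token targets.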
                shift_logits = outputs[..., :-1, :]
                shift_labels = labels[..., 1:]

                loss = self.criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                tracker.add({'loss': loss})

                tracker.save()
                tracker.add_global_step()
            tracker.new_line()


Since Trainer.text defaults to the string "tiny_shakespeare", labml uses this option function to build the dataset when the experiment starts.

@option(Trainer.text)
def tiny_shakespeare(c: Trainer):
    path = lab.get_data_path() / 'tiny_shakespeare.txt'
    if not path.exists():
        download_file("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", path)
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()

    tokens = c.tokenizer.encode(text)
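    # Keep only enough tokens to fill whole (batch_size x context_len) blocks,
    # then split the stream into rows of context_len token ids.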
    num_batches = len(tokens) // (c.batch_size * c.context_len)
    tokens = tokens[:num_batches * c.batch_size * c.context_len]
    input_ids = torch.tensor(tokens).view(-1, c.context_len)
    return TensorDataset(input_ids)
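To actually fine-tune, the notebook would construct the Trainer, initialize it, and call run(). Below is a rough sketch of how that might look with labml's experiment API; the experiment name and the override dict are illustrative and not part of the code above.

from labml import experiment


def main():
    # Create an experiment (the name here is arbitrary)
    experiment.create(name="lora_gpt2_tiny_shakespeare")
    conf = Trainer()
    # The default configs can be overridden at this point, e.g. a smaller batch size
    experiment.configs(conf, {'batch_size': 16})
    conf.initialize()
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()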