import inspect
import math

import torch
import torch.nn as nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.rwkv import RWKV, TimeMixing
from labml_nn.rwkv.configs import RWKVConfigs


class Configs(NLPAutoRegressionConfigs):
"""
## Configurations
This inherits from
[`NLPAutoRegressionConfigs`](../../experiments/nlp_autoregression.html#NLPAutoRegressionConfigs)
"""
# RWKV model
model: RWKV
rwkv: RWKVConfigs
    # Number of warmup iterations
    warmup_iters: int = 2000
    # Total number of training iterations
    max_iters: int = 600000
    # Weight decay
    weight_decay: float = 1e-1
    # AdamW betas
    beta1: float = 0.9
    beta2: float = 0.95
    # Use the custom optimizer defined below
    optimizer = 'rwkv_optimizer'


@option(Configs.rwkv, 'RWKV')
def _rwkv_configs(c: Configs):
"""
### RWKV configurations
"""
# We use our
# [configurable RWKV implementation](../configs.html#RWKVConfigs)
conf = RWKVConfigs()
# Set the vocabulary sizes for embeddings and generating logits
conf.n_src_vocab = c.n_tokens
conf.n_tgt_vocab = c.n_tokens
return conf


def _init_weights(module, rwkv: RWKVConfigs):
    # Initialize the vector parameters in `TimeMixing` blocks
if isinstance(module, TimeMixing):
layer_id = module.layer_id
n_layer = module.n_layer
n_embd = module.n_embd
attn_sz = n_embd
with torch.no_grad():
ratio_0_to_1 = layer_id / (n_layer - 1) # 0 to 1
ratio_1_to_almost0 = 1.0 - (layer_id / n_layer) # 1 to ~0
ddd = torch.ones(1, 1, n_embd)
for i in range(n_embd):
ddd[0, 0, i] = i / n_embd
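            # `ddd` is now a per-channel ramp from $0$ to almost $1$,
            # used as the base for the token-shift mixing ratios below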
decay_speed = torch.ones(attn_sz)
for h in range(attn_sz):
decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
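            # Per-channel decay speeds span $-5$ (first channel) to $3$ (last channel),
            # with a depth-dependent exponent controlling how they are spread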
module.time_decay = nn.Parameter(decay_speed)
zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
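            # `zigzag` cycles through $0, +0.5, -0.5$ so neighbouring channels get
            # slightly different initial `time_first` values around $\log 0.3$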
module.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
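            # Mixing ratios for key, value and receptance: the exponent shrinks from $1$
            # towards $0$ with depth, pushing these vectors closer to $1$ in deeper layers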
module.time_mix_key = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
module.time_mix_value = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
module.time_mix_receptance = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))


@option(Configs.model)
def _model(c: Configs):
"""
Create RWKV model and initialize weights
"""
m = RWKV(c.rwkv).to(c.device)
# Apply custom weight initialization
    m.apply(lambda module: _init_weights(module, c.rwkv))
return m


@option(NLPAutoRegressionConfigs.optimizer, 'rwkv_optimizer')
def _configure_optimizers(c: Configs):
# start with all of the candidate parameters
param_dict = {pn: p for pn, p in c.model.named_parameters()}
# filter out those that do not require grad
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
    # Create optim groups. Any parameter that is 2-D will be weight decayed; all others will not.
    # i.e. all weight tensors in matmuls + embeddings decay, while biases and layernorms don't.
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
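    # e.g. layer-norm parameters and the $1$-D `time_decay`/`time_first` vectors
    # initialized above fall in the no-decay group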
optim_groups = [
{'params': decay_params, 'weight_decay': c.weight_decay},
{'params': nodecay_params, 'weight_decay': 0.0}
]
num_decay_params = sum(p.numel() for p in decay_params)
num_nodecay_params = sum(p.numel() for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
# Create AdamW optimizer and use the fused version if it is available
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
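    # (`fused` is only present in newer PyTorch releases, hence the signature check)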
    use_fused = fused_available and c.device.type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(optim_groups, lr=c.learning_rate, betas=(c.beta1, c.beta2), **extra_args)
print(f"using fused AdamW: {use_fused}")
return optimizer


def main():
# Create experiment
experiment.create(name="RWKV")
# Create configs
conf = Configs()
# Override configurations
experiment.configs(conf, {
# Use character level tokenizer
'tokenizer': 'character',
# Prompt separator is blank
'prompt_separator': '',
# Starting prompt for sampling
'prompt': 'It is ',
# Use Tiny Shakespeare dataset
'text': 'tiny_shakespeare',
# Use a context size of $128$
'seq_len': 128,
# Train for $32$ epochs
'epochs': 32,
# Batch size $128$
'batch_size': 128,
# Switch between training and validation for $10$ times
# per epoch
'inner_iterations': 10,
        # RWKV model size
        'rwkv.block_size': 1024,
'rwkv.n_layer': 12,
'rwkv.n_heads': 12,
'rwkv.n_embd': 768
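        # ($12$ layers with $768$-dimensional embeddings is roughly GPT-2 small scale)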
})
print(conf.model)
# Set models for saving and loading
experiment.add_pytorch_models({'model': conf.model})
# Start the experiment
with experiment.start():
# Run training
conf.run()
#
if __name__ == '__main__':
main()