RWKV docs

This commit is contained in:
Varuna Jayasiri
2024-03-17 17:45:08 +05:30
parent 7db6e92376
commit df9e1af615
5 changed files with 186 additions and 207 deletions

View File

@ -124,9 +124,7 @@
<h4><a href="conv_mixer/index.html">ConvMixer</a></h4> <h4><a href="conv_mixer/index.html">ConvMixer</a></h4>
<h4><a href="capsule_networks/index.html">Capsule Networks</a></h4> <h4><a href="capsule_networks/index.html">Capsule Networks</a></h4>
<h4><a href="unet/index.html">U-Net</a></h4> <h4><a href="unet/index.html">U-Net</a></h4>
<h4><a href="sketch_rnn/index.html">RNNs</a></h4> <h4><a href="sketch_rnn/index.html">Sketch RNN</a></h4>
<ul><li><a href="rwkv/index.html">RWKV</a> </li>
<li><a href="sketch_rnn/index.html">Sketch RNN</a></li></ul>
<h4>✨ Graph Neural Networks</h4> <h4>✨ Graph Neural Networks</h4>
<ul><li><a href="graphs/gat/index.html">Graph Attention Networks (GAT)</a> </li> <ul><li><a href="graphs/gat/index.html">Graph Attention Networks (GAT)</a> </li>
<li><a href="graphs/gatv2/index.html">Graph Attention Networks v2 (GATv2)</a></li></ul> <li><a href="graphs/gatv2/index.html">Graph Attention Networks v2 (GATv2)</a></li></ul>
@ -170,7 +168,6 @@
<ul><li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf">Autoregressive Search Engines: Generating Substrings as Document Identifiers</a> </li> <ul><li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf">Autoregressive Search Engines: Generating Substrings as Document Identifiers</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf">Training Compute-Optimal Large Language Models</a> </li> <li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf">Training Compute-Optimal Large Language Models</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1910.02054.pdf">ZeRO: Memory Optimizations Toward Training Trillion Parameter Models</a> </li> <li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1910.02054.pdf">ZeRO: Memory Optimizations Toward Training Trillion Parameter Models</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/RWKV.pdf">RWKV: Reinventing RNNs for the Transformer Era</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.02311.pdf">PaLM: Scaling Language Modeling with Pathways</a> </li> <li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.02311.pdf">PaLM: Scaling Language Modeling with Pathways</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/dall-e-2.pdf">Hierarchical Text-Conditional Image Generation with CLIP Latents</a> </li> <li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/dall-e-2.pdf">Hierarchical Text-Conditional Image Generation with CLIP Latents</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.14465.pdf">STaR: Self-Taught Reasoner Bootstrapping Reasoning With Reasoning</a> </li> <li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.14465.pdf">STaR: Self-Taught Reasoner Bootstrapping Reasoning With Reasoning</a> </li>

View File

@ -503,6 +503,27 @@
</url>
<url>
    <loc>https://nn.labml.ai/RWKV/configs.html</loc>
    <lastmod>2024-03-17T16:30:00+00:00</lastmod>
    <priority>1.00</priority>
</url>
<url>
    <loc>https://nn.labml.ai/RWKV/index.html</loc>
    <lastmod>2024-03-17T16:30:00+00:00</lastmod>
    <priority>1.00</priority>
</url>
<url>
    <loc>https://nn.labml.ai/RWKV/experiment.html</loc>
    <lastmod>2024-03-17T16:30:00+00:00</lastmod>
    <priority>1.00</priority>
</url>
<url>
    <loc>https://nn.labml.ai/cfr/infoset_saver.html</loc>
    <lastmod>2021-06-21T16:30:00+00:00</lastmod>

View File

@ -9,60 +9,55 @@ summary: >
# Receptance Weighted Key Value (RWKV)

This is a tutorial/implementation of RWKV
from paper [RWKV: Reinventing RNNs for the Transformer Era](https://arxiv.org/pdf/2305.13048.pdf)
in [PyTorch](https://pytorch.org/).

Full definition of a RWKV Language Model, all of it in this single file.

References:
1) [the official RWKV PyTorch implementation released by Bo Peng](https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py)
2) [huggingface/transformers PyTorch implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py)
"""

import torch
import torch.nn as nn
from torch.nn import functional as F

from labml_helpers.module import Module

PREV_X_TIME = 0
NUM_STATE = 1
DEN_STATE = 2
MAX_STATE = 3
PREV_X_CHANNEL = 4
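A note on the state layout: these constants index the third dimension of a per-layer recurrent state tensor. A minimal sketch of allocating such a buffer, with the `[n_layer, batch, 5, n_embd]` shape inferred from how `state` is indexed in the mixing modules below (the helper itself is hypothetical, not part of this file):

```python
import torch

# Hypothetical helper: allocate a zero-initialised RWKV state buffer.
# Layout assumed from the indexing in TimeMixing/ChannelMixing:
# one slot per layer, five per-layer slices selected by the constants above.
def init_rwkv_state(n_layer: int, batch_size: int, n_embd: int) -> torch.Tensor:
    state = torch.zeros(n_layer, batch_size, 5, n_embd)
    # the running maximum starts very negative, mirroring the `- 1e38` in the WKV loop
    state[:, :, 3, :] = -1e38  # MAX_STATE slice
    return state

state = init_rwkv_state(n_layer=12, batch_size=2, n_embd=768)
print(state.shape)  # torch.Size([12, 2, 5, 768])
```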
"""
## Layernorm with bias
"""
class LayerNorm(Module): class LayerNorm(Module):
"""
### Layer normalization with bias
"""
def __init__(self, ndim, bias): def __init__(self, ndim, bias):
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.ones(ndim)) self.weight = nn.Parameter(torch.ones(ndim))
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
def forward(self, input): def forward(self, input):
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
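This is a thin wrapper around `F.layer_norm` so that the bias can be dropped entirely. A quick, self-contained check of that call on toy shapes (not from the diff):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)                  # (batch, time, channel)
weight = torch.ones(8)
bias = None                               # the bias=False case: F.layer_norm accepts None

out = F.layer_norm(x, weight.shape, weight, bias, 1e-5)
print(out.shape)                          # torch.Size([2, 4, 8])
print(out.mean(-1).abs().max() < 1e-5)    # each channel vector is re-centred
```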
"""
# L2 loss wrapper
https://github.com/BlinkDL/RWKV-LM/blob/cca1b5e8e597cf40675882bb10b46287c844e35c/RWKV-v4/src/model.py#L21
"""
class L2Wrap(torch.autograd.Function): class L2Wrap(torch.autograd.Function):
"""
### L2 loss wrapper
[ref](https://github.com/BlinkDL/RWKV-LM/blob/cca1b5e8e597cf40675882bb10b46287c844e35c/RWKV-v4/src/model.py#L21)
"""
@staticmethod @staticmethod
def forward(ctx, loss, y): def forward(ctx, loss, y):
ctx.save_for_backward(y) ctx.save_for_backward(y)
return loss return loss
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
y = ctx.saved_tensors[0] y = ctx.saved_tensors[0]
@ -71,118 +66,110 @@ class L2Wrap(torch.autograd.Function):
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return grad_output, gy
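What the wrapper adds to the gradient can be reproduced by hand. In the sketch below, `factor` is an assumption taken from the linked reference implementation (the line defining it falls outside the displayed hunk):

```python
import torch

# Toy logits: (batch=1, time=2, vocab=4)
y = torch.tensor([[[1.0, 3.0, 0.5, -1.0],
                   [0.2, 0.1, 2.5, 0.0]]])

# Assumed from the linked reference implementation (not shown in this hunk)
factor = 1e-4 / (y.shape[0] * y.shape[1])

maxx, ids = torch.max(y, -1, keepdim=True)
gy = torch.zeros_like(y)
gy.scatter_(-1, ids, maxx * factor)

print(gy)
# Only the arg-max logit of each step gets a small extra gradient,
# so the backward pass nudges the largest logits towards zero
# (an L2-like penalty on the peak logits) while the loss value itself is unchanged.
```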
class ChannelMixing(Module):
    """
    ### Channel Mixing
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        # token shifting
        self.layer_id = layer_id

        n_embd = config.n_embd
        intermediate_size = (
            config.intermediate_size if config.intermediate_size is not None else 4 * n_embd
        )

        # Learnable Matrix
        self.key_proj = nn.Linear(n_embd, intermediate_size, bias=False)
        self.value_proj = nn.Linear(intermediate_size, n_embd, bias=False)
        self.receptance_proj = nn.Linear(n_embd, n_embd, bias=False)

        # Learnable Vector
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

    def forward(self, x, state=None):
        # x = (Batch,Time,Channel)
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_CHANNEL], :]
            state[self.layer_id, :, [PREV_X_CHANNEL], :] = x
        else:
            prev_x = self.time_shift(x)

        # $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)

        # $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)

        # $V_t=W_v \cdot max(k_t,0)^2$
        value = self.value_proj(torch.square(torch.relu(key)))

        # $o_t=\sigma(r_t) \odot v_t$
        out = F.sigmoid(receptance) * value

        return out, state
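A self-contained sketch of the channel-mixing equations above using plain tensors; the shapes and the zero-pad token shift mirror the module, while the sizes and weights are made up for illustration:

```python
import torch
import torch.nn as nn

batch, time, n_embd, hidden = 2, 5, 8, 32

x = torch.randn(batch, time, n_embd)
# token shift: x_{t-1} along the time axis (zero for the first token),
# same trick as nn.ZeroPad2d((0, 0, 1, -1)) in the module
prev_x = nn.ZeroPad2d((0, 0, 1, -1))(x)

mu_k = torch.rand(1, 1, n_embd)
mu_r = torch.rand(1, 1, n_embd)
W_k = nn.Linear(n_embd, hidden, bias=False)
W_v = nn.Linear(hidden, n_embd, bias=False)
W_r = nn.Linear(n_embd, n_embd, bias=False)

# r_t = W_r (mu_r * x_t + (1 - mu_r) * x_{t-1})
r = W_r(mu_r * x + (1 - mu_r) * prev_x)
# k_t = W_k (mu_k * x_t + (1 - mu_k) * x_{t-1})
k = W_k(mu_k * x + (1 - mu_k) * prev_x)
# v_t = W_v max(k_t, 0)^2   (squared ReLU)
v = W_v(torch.square(torch.relu(k)))
# o_t = sigmoid(r_t) * v_t
out = torch.sigmoid(r) * v
print(out.shape)  # torch.Size([2, 5, 8])
```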
"""
## Time Mixing
"""
class TimeMixing(Module): class TimeMixing(Module):
def __init__(self,config,layer_id): """
### Time Mixing
"""
def __init__(self, config, layer_id):
super().__init__() super().__init__()
self.config = config self.config = config
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.layer_id = layer_id self.layer_id = layer_id
n_embd = config.n_embd n_embd = config.n_embd
attn_sz = n_embd attn_sz = n_embd
## learnable matrix # learnable matrix
self.key_proj = nn.Linear(n_embd, attn_sz, bias=False) self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
self.value_proj = nn.Linear(n_embd, attn_sz, bias=False) self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False) self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
self.output_proj = nn.Linear(attn_sz, n_embd, bias=False) self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)
## learnable vector # learnable vector
self.time_decay = nn.Parameter(torch.empty(attn_sz)) self.time_decay = nn.Parameter(torch.empty(attn_sz))
self.time_first = nn.Parameter(torch.empty(attn_sz)) self.time_first = nn.Parameter(torch.empty(attn_sz))
self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd)) self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd)) self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd)) self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))
def forward(self,x,state=None): def forward(self, x, state=None):
# x = (Batch,Time,Channel) """
x = (Batch,Time,Channel)
"""
if state is not None: if state is not None:
prev_x = state[self.layer_id,:,[PREV_X_TIME],:] prev_x = state[self.layer_id, :, [PREV_X_TIME], :]
state[self.layer_id,:,[PREV_X_TIME],:] = x state[self.layer_id, :, [PREV_X_TIME], :] = x
else: else:
prev_x = self.time_shift(x) prev_x = self.time_shift(x)
""" # $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
### $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
"""
receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance) receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
receptance = self.receptance_proj(receptance) receptance = self.receptance_proj(receptance)
""" # $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
### $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
"""
key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key) key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
key = self.key_proj(key) key = self.key_proj(key)
""" # $v_t=W_v \cdot (\mu_v x_t + (1-\mu_v)x_{t-1})$
### $v_t=W_v \cdot (\mu_v x_t + (1-\mu_v)x_{t-1})$
"""
value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value) value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
value = self.value_proj(value) value = self.value_proj(value)
""" # WKV calculation
## WKV calculation
"""
_, seq_length, _ = key.size() _, seq_length, _ = key.size()
output = torch.zeros_like(key) output = torch.zeros_like(key)
@ -191,9 +178,9 @@ class TimeMixing(Module):
            den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
        else:
            num_state = state[self.layer_id, :, NUM_STATE, :]
            den_state = state[self.layer_id, :, DEN_STATE, :]
            max_state = state[self.layer_id, :, MAX_STATE, :]

        time_decay = -torch.exp(self.time_decay)
@ -201,9 +188,7 @@ class TimeMixing(Module):
            current_key = key[:, current_index].float()
            current_value = value[:, current_index]

            # $wkv_t=\frac{\sum^{t-1}_{i=1}e^{-(t-1-i)w+k_i}v_i+e^{u+k_t}v_t}{\sum^{t-1}_{i=1}e^{-(t-1-i)w+k_i}+e^{u+k_t}}$
            max_for_output = torch.maximum(max_state, current_key + self.time_first)
            e1 = torch.exp(max_state - max_for_output)
            e2 = torch.exp(current_key + self.time_first - max_for_output)
@ -218,110 +203,100 @@ class TimeMixing(Module):
            num_state = e1 * num_state + e2 * current_value
            den_state = e1 * den_state + e2
            max_state = max_for_state

            # update states
            state[self.layer_id, :, NUM_STATE, :] = num_state
            state[self.layer_id, :, DEN_STATE, :] = den_state
            state[self.layer_id, :, MAX_STATE, :] = max_state

        wkv, state = self.wkv_function(key, value, use_customized_cuda_kernel=self.config.use_customized_cuda_kernel,
                                       state=state)

        # $o_t=W_o \cdot (\sigma(r_t) \odot wkv_t)$
        rwkv = F.sigmoid(receptance) * wkv
        rwkv = self.output_proj(rwkv)

        return rwkv, state
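The loop above is a numerically stabilised form of the WKV recurrence in the comment. A compact sketch of the same recurrence without the max-subtraction trick, on toy tensors (all names are local to the example):

```python
import torch

torch.manual_seed(0)
T, C = 6, 4                     # time steps, channels
k = torch.randn(T, C)           # keys   k_t
v = torch.randn(T, C)           # values v_t
w = torch.rand(C)               # per-channel decay (w >= 0), applied as e^{-w} per step
u = torch.randn(C)              # "time_first" bonus for the current token

wkv = torch.zeros(T, C)
num = torch.zeros(C)            # running numerator   sum_i e^{-(t-1-i)w + k_i} v_i
den = torch.zeros(C)            # running denominator sum_i e^{-(t-1-i)w + k_i}
for t in range(T):
    # the current token gets the u bonus instead of the decay
    wkv[t] = (num + torch.exp(u + k[t]) * v[t]) / (den + torch.exp(u + k[t]))
    # decay the history by e^{-w} and absorb the current token into it
    num = torch.exp(-w) * num + torch.exp(k[t]) * v[t]
    den = torch.exp(-w) * den + torch.exp(k[t])

print(wkv.shape)  # torch.Size([6, 4])
```

The max-state bookkeeping in the module exists only to keep these exponentials in a safe range; it does not change the value being computed.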
"""
## RWKV block element
"""
class Block(Module):
def __init__(self, config,layer_id): class Block(Module):
"""
## RWKV block element
"""
def __init__(self, config, layer_id):
super().__init__() super().__init__()
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
self.attn = TimeMixing(config,layer_id) self.attn = TimeMixing(config, layer_id)
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
self.ffn = ChannelMixing(config,layer_id) self.ffn = ChannelMixing(config, layer_id)
def forward(self, x, state = None): def forward(self, x, state=None):
# state: [batch_size, 5 , n_embd] # state: [batch_size, 5 , n_embd]
"""
## time mixing # time mixing
"""
residual = x residual = x
x,state = self.attn(self.ln_1(x),state=state) x, state = self.attn(self.ln_1(x), state=state)
x = x + residual x = x + residual
"""
## channel mixing # channel mixing
"""
residual = x residual = x
x, state = self.ffn(self.ln_2(x),state=state) x, state = self.ffn(self.ln_2(x), state=state)
x = x + residual x = x + residual
return x, state return x, state
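The block is the usual pre-norm residual wiring. A generic stand-in with ordinary `nn.Module`s in place of the mixing layers, purely illustrative:

```python
import torch
import torch.nn as nn

# Pre-norm residual wiring mirroring Block.forward above
# (Linear layers stand in for TimeMixing / ChannelMixing).
class ToyBlock(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.ln_1 = nn.LayerNorm(d)
        self.attn = nn.Linear(d, d)   # stand-in for TimeMixing
        self.ln_2 = nn.LayerNorm(d)
        self.ffn = nn.Linear(d, d)    # stand-in for ChannelMixing

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))   # time-mixing sub-layer + residual
        x = x + self.ffn(self.ln_2(x))    # channel-mixing sub-layer + residual
        return x

print(ToyBlock(8)(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 5, 8])
```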
class RWKV(Module):
    """
    ## RWKV
    """

    def __init__(self, config, lr_init=0.0008):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.lr_init = lr_init  ## used to initialize embedding parameters
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd

        # Initiate model layers
        self.rwkv = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            ln_p=LayerNorm(config.n_embd, bias=config.bias),
            h=nn.ModuleList([Block(config, layer_id) for layer_id in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))

        # Output linear layer
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, idx, targets=None, state=None, return_state=False):
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        # Embedding Layer
        x = self.rwkv.wte(idx)

        # Layer Norm
        x = self.rwkv.ln_p(x)

        # RWKV Blocks
        for block_idx, block in enumerate(self.rwkv.h):
            x, state = block(x, state)
        x = self.rwkv.ln_f(x)

        # Logit Layer and loss Function (for training)
        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            if self.training:
                loss = L2Wrap.apply(loss, logits)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None

        # Return Logits and loss
        if return_state:
            return logits, loss, state
        else:
            return logits, loss
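For orientation, a hedged sketch of how a model like this is driven. The config fields follow the attributes read in `__init__`/`forward` above, the import path is assumed from the `labml_nn.RWKV.configs` path used later in this commit, and parameters such as `time_decay` and the `time_mix_*` vectors still need the initialisation shown in the experiment file below:

```python
import torch
from types import SimpleNamespace
from labml_nn.RWKV import RWKV  # assumed import path for the module in this diff

# Hypothetical minimal config; field names match the attributes the model reads,
# the values are illustrative only.
config = SimpleNamespace(vocab_size=256, block_size=64, n_embd=32, n_layer=2,
                         bias=False, intermediate_size=None,
                         use_customized_cuda_kernel=False)

model = RWKV(config)  # time_decay / time_mix_* are torch.empty and must be initialised
                      # before the outputs mean anything (see `_init_weights` below)

idx = torch.randint(0, config.vocab_size, (1, 16))       # (batch, time) token ids
targets = torch.randint(0, config.vocab_size, (1, 16))

logits, loss = model(idx, targets=targets)                # training-style call
logits, _, state = model(idx, return_state=True)          # keep the recurrent state
next_logits, _, state = model(idx[:, -1:], state=state,   # then feed one token at a time
                              return_state=True)
```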

View File

@ -1,15 +1,8 @@
from labml.configs import BaseConfigs


class RWKVConfigs(BaseConfigs):
    """
    ## Transformer Configurations

    This defines configurations for a transformer.

View File

@ -1,22 +1,16 @@
import inspect
import math

import torch
import torch.nn as nn

from labml_nn.RWKV.configs import RWKVConfigs
from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs


class Configs(NLPAutoRegressionConfigs):
    """
@ -33,7 +27,7 @@ class Configs(NLPAutoRegressionConfigs):
    # number of warmup iterations
    warmup_iters: int = 2000
    # total number of training iterations
    max_iters: int = 600000
    # weight decay
    weight_decay: float = 1e-1
    # Custom optimizer
@ -41,6 +35,7 @@ class Configs(NLPAutoRegressionConfigs):
    beta2: float = 0.95
    optimizer = 'rwkv_optimizer'


@option(Configs.rwkv, 'RWKV')
def _rwkv_configs(c: Configs):
    """
@ -58,14 +53,13 @@ def _rwkv_configs(c: Configs):
def _init_weights(module, rwkv: RWKVConfigs):
    # initialize Vector Parameters in TimeMixing
    if isinstance(module, TimeMixing):
        layer_id = module.layer_id
        n_layer = module.n_layer
        n_embd = module.n_embd
        attn_sz = n_embd

        with torch.no_grad():
            ratio_0_to_1 = layer_id / (n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)  # 1 to ~0
@ -83,7 +77,6 @@ def _init_weights(module, rwkv: RWKVConfigs):
            module.time_mix_key = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            module.time_mix_value = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            module.time_mix_receptance = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
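To see what these curves do across depth, a standalone sketch computing the same quantities for one layer; the definition of `ddd` is outside the displayed hunk, so building it from the channel index (as in the reference RWKV initialisation) is an assumption:

```python
import torch

n_layer, n_embd = 12, 768
layer_id = 3

ratio_0_to_1 = layer_id / (n_layer - 1)           # 0 to 1 across layers
ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)   # 1 to ~0 across layers

# `ddd` is not shown in this hunk; the reference implementation builds it
# from the channel index scaled to [0, 1), which is what we assume here.
ddd = (torch.arange(n_embd) / n_embd).reshape(1, 1, n_embd)

time_mix_key = torch.pow(ddd, ratio_1_to_almost0)
time_mix_value = torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
time_mix_receptance = torch.pow(ddd, 0.5 * ratio_1_to_almost0)

# Early layers mix in more of the shifted token x_{t-1}; deeper layers lean
# more on the current token x_t, because the exponent shrinks with depth.
print(time_mix_key.mean().item(), time_mix_receptance.mean().item())
```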
@option(Configs.model)
@ -101,30 +94,30 @@ def _model(c: Configs):
@option(NLPAutoRegressionConfigs.optimizer)
def _configure_optimizers(c: NLPAutoRegressionConfigs):
    # start with all of the candidate parameters
    param_dict = {pn: p for pn, p in c.model.named_parameters()}
    # filter out those that do not require grad
    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
    # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
    # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': c.weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0}
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
    # Create AdamW optimizer and use the fused version if it is available
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and c.device_type == 'cuda'
    extra_args = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(optim_groups, lr=c.learning_rate, betas=c.betas, **extra_args)
    print(f"using fused AdamW: {use_fused}")
    return optimizer
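The dimension-based decay split can be sanity-checked on any small module (the toy model below is arbitrary):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]    # matrix-shaped: Linear weight
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]   # vectors: biases, LayerNorm params

optim_groups = [
    {'params': decay_params, 'weight_decay': 0.1},
    {'params': nodecay_params, 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=3e-4, betas=(0.9, 0.95))
print(len(decay_params), len(nodecay_params))  # 1 decayed tensor, 3 non-decayed tensors
```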
def main():
@ -154,11 +147,11 @@ def main():
        # per epoch
        'inner_iterations': 10,

        'rwkv.block_size': 1024,
        # model
        'rwkv.n_layer': 12,
        'rwkv.n_heads': 12,
        'rwkv.n_embd': 768
    })

    print(conf.model)