RWKV docs

This commit is contained in:
Varuna Jayasiri
2024-03-17 17:45:08 +05:30
parent 7db6e92376
commit df9e1af615
5 changed files with 186 additions and 207 deletions

View File

@ -124,9 +124,7 @@
<h4><a href="conv_mixer/index.html">ConvMixer</a></h4>
<h4><a href="capsule_networks/index.html">Capsule Networks</a></h4>
<h4><a href="unet/index.html">U-Net</a></h4>
<h4><a href="sketch_rnn/index.html">RNNs</a></h4>
<ul><li><a href="rwkv/index.html">RWKV</a> </li>
<li><a href="sketch_rnn/index.html">Sketch RNN</a></li></ul>
<h4><a href="sketch_rnn/index.html">Sketch RNN</a></h4>
<h4>✨ Graph Neural Networks</h4>
<ul><li><a href="graphs/gat/index.html">Graph Attention Networks (GAT)</a> </li>
<li><a href="graphs/gatv2/index.html">Graph Attention Networks v2 (GATv2)</a></li></ul>
@ -170,7 +168,6 @@
<ul><li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf">Autoregressive Search Engines: Generating Substrings as Document Identifiers</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf">Training Compute-Optimal Large Language Models</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1910.02054.pdf">ZeRO: Memory Optimizations Toward Training Trillion Parameter Models</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/RWKV.pdf">RWKV: Reinventing RNNs for the Transformer Era</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.02311.pdf">PaLM: Scaling Language Modeling with Pathways</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/dall-e-2.pdf">Hierarchical Text-Conditional Image Generation with CLIP Latents</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.14465.pdf">STaR: Self-Taught Reasoner Bootstrapping Reasoning With Reasoning</a> </li>

View File

@ -503,6 +503,27 @@
</url>
<url>
<loc>https://nn.labml.ai/RWKV/configs.html</loc>
<lastmod>2024-03-17T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/RWKV/index.html</loc>
<lastmod>2024-03-17T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/RWKV/experiment.html</loc>
<lastmod>2024-03-17T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/cfr/infoset_saver.html</loc>
<lastmod>2021-06-21T16:30:00+00:00</lastmod>

View File

@ -9,43 +9,34 @@ summary: >
# Receptance Weighted Key Value (RWKV)
##TODO: make colab ?
This is a tutorial/implementation of RWKV
from paper [RWKV: Reinventing RNNs for the Transformer Era](https://arxiv.org/pdf/2305.13048.pdf)
in [PyTorch](https://pytorch.org/).
Full definition of a RWKV Language Model, all of it in this single file.
References:
1) the official RWKV PyTorch implementation released by Bo Peng:
https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
1) [the official RWKV PyTorch implementation released by Bo Peng](https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py)
2) [huggingface/transformers PyTorch implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py)
"""
import math,time
import os
import inspect
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from labml_helpers.module import Module
PREV_X_TIME = 0
NUM_STATE = 1
DEN_STATE = 2
MAX_STATE = 3
PREV_X_CHANNEL = 4
"""
## Layernorm with bias
"""
class LayerNorm(Module):
"""
### Layer normalization with bias
"""
def __init__(self, ndim, bias):
super().__init__()
self.weight = nn.Parameter(torch.ones(ndim))
@ -54,15 +45,19 @@ class LayerNorm(Module):
def forward(self, input):
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
"""
# L2 loss wrapper
https://github.com/BlinkDL/RWKV-LM/blob/cca1b5e8e597cf40675882bb10b46287c844e35c/RWKV-v4/src/model.py#L21
"""
class L2Wrap(torch.autograd.Function):
"""
### L2 loss wrapper
[ref](https://github.com/BlinkDL/RWKV-LM/blob/cca1b5e8e597cf40675882bb10b46287c844e35c/RWKV-v4/src/model.py#L21)
"""
@staticmethod
def forward(ctx, loss, y):
ctx.save_for_backward(y)
return loss
@staticmethod
def backward(ctx, grad_output):
y = ctx.saved_tensors[0]
@ -71,13 +66,15 @@ class L2Wrap(torch.autograd.Function):
maxx, ids = torch.max(y, -1, keepdim=True)
gy = torch.zeros_like(y)
gy.scatter_(-1, ids, maxx * factor)
return (grad_output, gy)
return grad_output, gy
class ChannelMixing(Module):
"""
## Channel Mixing
### Channel Mixing
"""
def __init__(self,config,layer_id):
def __init__(self, config, layer_id):
super().__init__()
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
# token shifting
@ -88,51 +85,47 @@ class ChannelMixing(Module):
config.intermediate_size if config.intermediate_size is not None else 4 * n_embd
)
## Learnable Matrix
self.key_proj = nn.Linear(n_embd,intermediate_size,bias=False)
self.value_proj = nn.Linear(intermediate_size,n_embd,bias=False)
self.receptance_proj = nn.Linear(n_embd,n_embd,bias=False)
# Learnable Matrix
self.key_proj = nn.Linear(n_embd, intermediate_size, bias=False)
self.value_proj = nn.Linear(intermediate_size, n_embd, bias=False)
self.receptance_proj = nn.Linear(n_embd, n_embd, bias=False)
## Learnable Vector
self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
# Learnable Vector
self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))
def forward(self,x,state=None):
def forward(self, x, state=None):
"""
# x = (Batch,Time,Channel)
"""
if state is not None:
prev_x = state[self.layer_id,:,[PREV_X_CHANNEL],:]
state[self.layer_id,:,[PREV_X_CHANNEL],:] = x
prev_x = state[self.layer_id, :, [PREV_X_CHANNEL], :]
state[self.layer_id, :, [PREV_X_CHANNEL], :] = x
else:
prev_x = self.time_shift(x)
"""
### $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
"""
# $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
receptance = self.receptance_proj(receptance)
"""
### $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
"""
# $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
key = self.key_proj(key)
"""
### $V_t=W_v \cdot max(k_t,0)^2$
"""
# $V_t=W_v \cdot max(k_t,0)^2$
value = self.value_proj(torch.square(torch.relu(key)))
"""
### $o_t=\sigma(r_t) \odot v_t$
"""
# $o_t=\sigma(r_t) \odot v_t$
out = F.sigmoid(receptance) * value
return out, state
"""
## Time Mixing
"""
class TimeMixing(Module):
def __init__(self,config,layer_id):
"""
### Time Mixing
"""
def __init__(self, config, layer_id):
super().__init__()
self.config = config
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
@ -141,48 +134,42 @@ class TimeMixing(Module):
n_embd = config.n_embd
attn_sz = n_embd
## learnable matrix
self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
# learnable matrix
self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)
self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)
## learnable vector
self.time_decay = nn.Parameter(torch.empty(attn_sz))
self.time_first = nn.Parameter(torch.empty(attn_sz))
self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
# learnable vector
self.time_decay = nn.Parameter(torch.empty(attn_sz))
self.time_first = nn.Parameter(torch.empty(attn_sz))
self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))
def forward(self,x,state=None):
# x = (Batch,Time,Channel)
def forward(self, x, state=None):
"""
x = (Batch,Time,Channel)
"""
if state is not None:
prev_x = state[self.layer_id,:,[PREV_X_TIME],:]
state[self.layer_id,:,[PREV_X_TIME],:] = x
prev_x = state[self.layer_id, :, [PREV_X_TIME], :]
state[self.layer_id, :, [PREV_X_TIME], :] = x
else:
prev_x = self.time_shift(x)
"""
### $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
"""
# $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
receptance = self.receptance_proj(receptance)
"""
### $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
"""
# $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
key = self.key_proj(key)
"""
### $v_t=W_v \cdot (\mu_v x_t + (1-\mu_v)x_{t-1})$
"""
# $v_t=W_v \cdot (\mu_v x_t + (1-\mu_v)x_{t-1})$
value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
value = self.value_proj(value)
"""
## WKV calculation
"""
# WKV calculation
_, seq_length, _ = key.size()
output = torch.zeros_like(key)
@ -191,9 +178,9 @@ class TimeMixing(Module):
den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
else:
num_state = state[self.layer_id,:,NUM_STATE,:]
den_state = state[self.layer_id,:,DEN_STATE,:]
max_state = state[self.layer_id,:,MAX_STATE,:]
num_state = state[self.layer_id, :, NUM_STATE, :]
den_state = state[self.layer_id, :, DEN_STATE, :]
max_state = state[self.layer_id, :, MAX_STATE, :]
time_decay = -torch.exp(self.time_decay)
@ -201,9 +188,7 @@ class TimeMixing(Module):
current_key = key[:, current_index].float()
current_value = value[:, current_index]
"""
### $wkv_t=\frac{\sum^{t-1}_{i=1}d^{-(t-1-i)w+k_i}v_i+e^{u+k_t}v_t}{\sum^{t-1}_{i=1}e^{-(t-1-i)w+k_i}+e^{u+k_t}}$
"""
# $wkv_t=\frac{\sum^{t-1}_{i=1}d^{-(t-1-i)w+k_i}v_i+e^{u+k_t}v_t}{\sum^{t-1}_{i=1}e^{-(t-1-i)w+k_i}+e^{u+k_t}}$
max_for_output = torch.maximum(max_state, current_key + self.time_first)
e1 = torch.exp(max_state - max_for_output)
e2 = torch.exp(current_key + self.time_first - max_for_output)
@ -219,109 +204,99 @@ class TimeMixing(Module):
den_state = e1 * den_state + e2
max_state = max_for_state
"""
### update states
"""
state[self.layer_id,:,NUM_STATE,:] = num_state
state[self.layer_id,:,DEN_STATE,:] = den_state
state[self.layer_id,:,MAX_STATE,:] = max_state
wkv, state = self.wkv_function(key,value,use_customized_cuda_kernel=self.config.use_customized_cuda_kernel,state=state)
# update states
state[self.layer_id, :, NUM_STATE, :] = num_state
state[self.layer_id, :, DEN_STATE, :] = den_state
state[self.layer_id, :, MAX_STATE, :] = max_state
wkv, state = self.wkv_function(key, value, use_customized_cuda_kernel=self.config.use_customized_cuda_kernel,
state=state)
"""
### $o_t=W_o \cdot (\sigma(r_t) \odot wkv_t)$
"""
# $o_t=W_o \cdot (\sigma(r_t) \odot wkv_t)$
rwkv = F.sigmoid(receptance) * wkv
rwkv = self.output_proj(rwkv)
return rwkv, state
"""
## RWKV block element
"""
class Block(Module):
def __init__(self, config,layer_id):
class Block(Module):
"""
## RWKV block element
"""
def __init__(self, config, layer_id):
super().__init__()
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
self.attn = TimeMixing(config,layer_id)
self.attn = TimeMixing(config, layer_id)
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
self.ffn = ChannelMixing(config,layer_id)
self.ffn = ChannelMixing(config, layer_id)
def forward(self, x, state = None):
def forward(self, x, state=None):
# state: [batch_size, 5 , n_embd]
"""
## time mixing
"""
# time mixing
residual = x
x,state = self.attn(self.ln_1(x),state=state)
x, state = self.attn(self.ln_1(x), state=state)
x = x + residual
"""
## channel mixing
"""
# channel mixing
residual = x
x, state = self.ffn(self.ln_2(x),state=state)
x, state = self.ffn(self.ln_2(x), state=state)
x = x + residual
return x, state
class RWKV(Module):
def __init__(self, config,lr_init=0.0008):
"""
## RWKV
"""
def __init__(self, config, lr_init=0.0008):
super().__init__()
assert config.vocab_size is not None
assert config.block_size is not None
self.config = config
self.lr_init = lr_init ## used to initialize embedding parameters
self.lr_init = lr_init ## used to initialize embedding parameters
self.n_layer = config.n_layer
self.n_embd = config.n_embd
"""
## Initiate model layers
"""
self.rwkv = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd),
ln_p = LayerNorm(config.n_embd, bias=config.bias),
h = nn.ModuleList([Block(config,layer_id) for layer_id in range(config.n_layer)]),
ln_f = LayerNorm(config.n_embd, bias=config.bias),
))
"""
## Output linear layer
"""
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# Initiate model layers
self.rwkv = nn.ModuleDict(dict(
wte=nn.Embedding(config.vocab_size, config.n_embd),
ln_p=LayerNorm(config.n_embd, bias=config.bias),
h=nn.ModuleList([Block(config, layer_id) for layer_id in range(config.n_layer)]),
ln_f=LayerNorm(config.n_embd, bias=config.bias),
))
# Output linear layer
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
def forward(self, idx, targets=None, state=None, return_state=False):
b, t = idx.size()
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
"""
## Embedding Layer
"""
# Embedding Layer
x = self.rwkv.wte(idx)
"""
## Layer Norm
"""
# Layer Norm
x = self.rwkv.ln_p(x)
"""
## RWKV Blocks
"""
for block_idx,block in enumerate(self.rwkv.h):
x, state = block(x,state)
# RWKV Blocks
for block_idx, block in enumerate(self.rwkv.h):
x, state = block(x, state)
x = self.rwkv.ln_f(x)
"""
## Logit Layer and loss Function (for training)
"""
# Logit Layer and loss Function (for training)
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.lm_head(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
if self.training:
loss = L2Wrap.apply(loss,logits)
loss = L2Wrap.apply(loss, logits)
else:
# inference-time mini-optimization: only forward the lm_head on the very last position
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
loss = None
"""
## Return Logits and loss
"""
# Return Logits and loss
if return_state:
return logits, loss, state
else:

View File

@ -1,15 +1,8 @@
import copy
import torch.nn as nn
from labml.configs import BaseConfigs, option, calculate, aggregate
from labml_helpers.module import Module
from labml.configs import BaseConfigs
class RWKVConfigs(BaseConfigs):
"""
<a id="TransformerConfigs"></a>
## Transformer Configurations
This defines configurations for a transformer.

View File

@ -1,22 +1,16 @@
import math,time
import os
import inspect
from dataclasses import dataclass
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from labml_nn.RWKV.configs import RWKVConfigs
from __init__ import RWKV
from __init__ import TimeMixing
from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.optimizers.configs import OptimizerConfigs
from configs import RWKVConfigs
from __init__ import RWKV
from __init__ import TimeMixing
from __init__ import ChannelMixing
from contextlib import nullcontext
class Configs(NLPAutoRegressionConfigs):
"""
@ -41,6 +35,7 @@ class Configs(NLPAutoRegressionConfigs):
beta2: float = 0.95
optimizer = 'rwkv_optimizer'
@option(Configs.rwkv, 'RWKV')
def _rwkv_configs(c: Configs):
"""
@ -58,9 +53,8 @@ def _rwkv_configs(c: Configs):
def _init_weights(module, rwkv: RWKVConfigs):
## initialize Vector Parameters in TimeMixing
if isinstance(module,TimeMixing):
# initialize Vector Parameters in TimeMixing
if isinstance(module, TimeMixing):
layer_id = module.layer_id
n_layer = module.n_layer
n_embd = module.n_embd
@ -85,7 +79,6 @@ def _init_weights(module, rwkv: RWKVConfigs):
module.time_mix_receptance = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
@option(Configs.model)
def _model(c: Configs):
"""
@ -101,30 +94,30 @@ def _model(c: Configs):
@option(NLPAutoRegressionConfigs.optimizer)
def _configure_optimizers(c: NLPAutoRegressionConfigs):
# start with all of the candidate parameters
param_dict = {pn: p for pn, p in c.model.named_parameters()}
# filter out those that do not require grad
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
# create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
optim_groups = [
{'params': decay_params, 'weight_decay': c.weight_decay},
{'params': nodecay_params, 'weight_decay': 0.0}
]
num_decay_params = sum(p.numel() for p in decay_params)
num_nodecay_params = sum(p.numel() for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
# Create AdamW optimizer and use the fused version if it is available
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and c.device_type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=c.learning_rate, betas=c.betas, **extra_args)
print(f"using fused AdamW: {use_fused}")
# start with all of the candidate parameters
param_dict = {pn: p for pn, p in c.model.named_parameters()}
# filter out those that do not require grad
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
# create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
optim_groups = [
{'params': decay_params, 'weight_decay': c.weight_decay},
{'params': nodecay_params, 'weight_decay': 0.0}
]
num_decay_params = sum(p.numel() for p in decay_params)
num_nodecay_params = sum(p.numel() for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
# Create AdamW optimizer and use the fused version if it is available
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and c.device_type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=c.learning_rate, betas=c.betas, **extra_args)
print(f"using fused AdamW: {use_fused}")
return optimizer
return optimizer
def main():
@ -154,11 +147,11 @@ def main():
# per epoch
'inner_iterations': 10,
'rwkv.block_size' : 1024,
'rwkv.block_size': 1024,
# model
'rwkv.n_layer' : 12,
'rwkv.n_heads' : 12,
'rwkv.n_embd' : 768
'rwkv.n_layer': 12,
'rwkv.n_heads': 12,
'rwkv.n_embd': 768
})
print(conf.model)