RWKV docs
@@ -124,9 +124,7 @@
 <h4>✨ <a href="conv_mixer/index.html">ConvMixer</a></h4>
 <h4>✨ <a href="capsule_networks/index.html">Capsule Networks</a></h4>
 <h4>✨ <a href="unet/index.html">U-Net</a></h4>
-<h4>✨ <a href="sketch_rnn/index.html">RNNs</a></h4>
-<ul><li><a href="rwkv/index.html">RWKV</a> </li>
-<li><a href="sketch_rnn/index.html">Sketch RNN</a></li></ul>
+<h4>✨ <a href="sketch_rnn/index.html">Sketch RNN</a></h4>
 <h4>✨ Graph Neural Networks</h4>
 <ul><li><a href="graphs/gat/index.html">Graph Attention Networks (GAT)</a> </li>
 <li><a href="graphs/gatv2/index.html">Graph Attention Networks v2 (GATv2)</a></li></ul>

@@ -170,7 +168,6 @@
 <ul><li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf">Autoregressive Search Engines: Generating Substrings as Document Identifiers</a> </li>
 <li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf">Training Compute-Optimal Large Language Models</a> </li>
 <li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1910.02054.pdf">ZeRO: Memory Optimizations Toward Training Trillion Parameter Models</a> </li>
-<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/RWKV.pdf">RWKV: Reinventing RNNs for the Transformer Era</a> </li>
 <li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.02311.pdf">PaLM: Scaling Language Modeling with Pathways</a> </li>
 <li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/dall-e-2.pdf">Hierarchical Text-Conditional Image Generation with CLIP Latents</a> </li>
 <li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.14465.pdf">STaR: Self-Taught Reasoner Bootstrapping Reasoning With Reasoning</a> </li>

@@ -503,6 +503,27 @@
 </url>


+<url>
+    <loc>https://nn.labml.ai/RWKV/configs.html</loc>
+    <lastmod>2024-03-17T16:30:00+00:00</lastmod>
+    <priority>1.00</priority>
+</url>
+
+
+<url>
+    <loc>https://nn.labml.ai/RWKV/index.html</loc>
+    <lastmod>2024-03-17T16:30:00+00:00</lastmod>
+    <priority>1.00</priority>
+</url>
+
+
+<url>
+    <loc>https://nn.labml.ai/RWKV/experiment.html</loc>
+    <lastmod>2024-03-17T16:30:00+00:00</lastmod>
+    <priority>1.00</priority>
+</url>
+
+
 <url>
     <loc>https://nn.labml.ai/cfr/infoset_saver.html</loc>
     <lastmod>2021-06-21T16:30:00+00:00</lastmod>

@@ -9,60 +9,55 @@ summary: >

 # Receptance Weighted Key Value (RWKV)

-##TODO: make colab ?

 This is a tutorial/implementation of RWKV
 from paper [RWKV: Reinventing RNNs for the Transformer Era](https://arxiv.org/pdf/2305.13048.pdf)
 in [PyTorch](https://pytorch.org/).

 Full definition of a RWKV Language Model, all of it in this single file.
 References:
-1) the official RWKV PyTorch implementation released by Bo Peng:
-https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py
-2) huggingface/transformers PyTorch implementation:
-https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
+1) [the official RWKV PyTorch implementation released by Bo Peng](https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py)
+2) [huggingface/transformers PyTorch implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py)
 """

-import math,time
-import os
-import inspect
-from dataclasses import dataclass

 import torch
 import torch.nn as nn
 from torch.nn import functional as F

 from labml_helpers.module import Module


 PREV_X_TIME = 0
 NUM_STATE = 1
 DEN_STATE = 2
 MAX_STATE = 3
 PREV_X_CHANNEL = 4

-"""
-## Layernorm with bias
-"""
 class LayerNorm(Module):
+    """
+    ### Layer normalization with bias
+    """

     def __init__(self, ndim, bias):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(ndim))
         self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

     def forward(self, input):
         return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


-"""
-# L2 loss wrapper
-https://github.com/BlinkDL/RWKV-LM/blob/cca1b5e8e597cf40675882bb10b46287c844e35c/RWKV-v4/src/model.py#L21
-"""
 class L2Wrap(torch.autograd.Function):
+    """
+    ### L2 loss wrapper
+
+    [ref](https://github.com/BlinkDL/RWKV-LM/blob/cca1b5e8e597cf40675882bb10b46287c844e35c/RWKV-v4/src/model.py#L21)
+    """

     @staticmethod
     def forward(ctx, loss, y):
         ctx.save_for_backward(y)
         return loss

     @staticmethod
     def backward(ctx, grad_output):
         y = ctx.saved_tensors[0]

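The five constants above are the indices of the slots in the recurrent state tensor that the time-mixing and channel-mixing modules below read and write through `state[layer_id, :, SLOT, :]`. As a quick orientation, here is a minimal sketch of how such a state could be allocated; the helper name and the `[n_layer, batch_size, 5, n_embd]` shape are assumptions inferred from how the indices are used later in this diff, not part of the repository code.

```python
import torch

# Hypothetical helper (not in the diff): allocate the recurrent state whose
# third dimension is indexed by PREV_X_TIME, NUM_STATE, DEN_STATE, MAX_STATE
# and PREV_X_CHANNEL.
def init_state(n_layer: int, batch_size: int, n_embd: int) -> torch.Tensor:
    state = torch.zeros(n_layer, batch_size, 5, n_embd)
    # MAX_STATE holds a running maximum for the numerically stable WKV
    # recurrence, so it starts at a very large negative number.
    state[:, :, 3, :] = -1e38
    return state

state = init_state(n_layer=12, batch_size=2, n_embd=768)
prev_x_time = state[0, :, [0], :]   # mirrors state[layer_id, :, [PREV_X_TIME], :] in the model code
```
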
@@ -71,118 +66,110 @@ class L2Wrap(torch.autograd.Function):
         maxx, ids = torch.max(y, -1, keepdim=True)
         gy = torch.zeros_like(y)
         gy.scatter_(-1, ids, maxx * factor)
-        return (grad_output, gy)
+        return grad_output, gy


 class ChannelMixing(Module):
     """
-    ## Channel Mixing
+    ### Channel Mixing
     """
-    def __init__(self,config,layer_id):
+
+    def __init__(self, config, layer_id):
         super().__init__()
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
         # token shifting
         self.layer_id = layer_id

         n_embd = config.n_embd
         intermediate_size = (
             config.intermediate_size if config.intermediate_size is not None else 4 * n_embd
         )

-        ## Learnable Matrix
-        self.key_proj = nn.Linear(n_embd,intermediate_size,bias=False)
-        self.value_proj = nn.Linear(intermediate_size,n_embd,bias=False)
-        self.receptance_proj = nn.Linear(n_embd,n_embd,bias=False)
+        # Learnable Matrix
+        self.key_proj = nn.Linear(n_embd, intermediate_size, bias=False)
+        self.value_proj = nn.Linear(intermediate_size, n_embd, bias=False)
+        self.receptance_proj = nn.Linear(n_embd, n_embd, bias=False)

-        ## Learnable Vector
+        # Learnable Vector
         self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
         self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

-    def forward(self,x,state=None):
+    def forward(self, x, state=None):
+        """
         # x = (Batch,Time,Channel)
+        """
         if state is not None:
-            prev_x = state[self.layer_id,:,[PREV_X_CHANNEL],:]
-            state[self.layer_id,:,[PREV_X_CHANNEL],:] = x
+            prev_x = state[self.layer_id, :, [PREV_X_CHANNEL], :]
+            state[self.layer_id, :, [PREV_X_CHANNEL], :] = x
         else:
             prev_x = self.time_shift(x)

-        """
-        ### $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
-        """
+        # $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
         receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
         receptance = self.receptance_proj(receptance)

-        """
-        ### $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
-        """
+        # $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
         key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
         key = self.key_proj(key)

-        """
-        ### $V_t=W_v \cdot max(k_t,0)^2$
-        """
+        # $V_t=W_v \cdot max(k_t,0)^2$
         value = self.value_proj(torch.square(torch.relu(key)))

-        """
-        ### $o_t=\sigma(r_t) \odot v_t$
-        """
+        # $o_t=\sigma(r_t) \odot v_t$
         out = F.sigmoid(receptance) * value
         return out, state

-"""
-## Time Mixing
-"""
 class TimeMixing(Module):
-    def __init__(self,config,layer_id):
+    """
+    ### Time Mixing
+    """
+
+    def __init__(self, config, layer_id):
         super().__init__()
         self.config = config
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
         self.layer_id = layer_id

         n_embd = config.n_embd
         attn_sz = n_embd

-        ## learnable matrix
+        # learnable matrix
         self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
         self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
         self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
         self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)

-        ## learnable vector
+        # learnable vector
         self.time_decay = nn.Parameter(torch.empty(attn_sz))
         self.time_first = nn.Parameter(torch.empty(attn_sz))
         self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
         self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
         self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

-    def forward(self,x,state=None):
-        # x = (Batch,Time,Channel)
+    def forward(self, x, state=None):
+        """
+        x = (Batch,Time,Channel)
+        """
         if state is not None:
-            prev_x = state[self.layer_id,:,[PREV_X_TIME],:]
-            state[self.layer_id,:,[PREV_X_TIME],:] = x
+            prev_x = state[self.layer_id, :, [PREV_X_TIME], :]
+            state[self.layer_id, :, [PREV_X_TIME], :] = x
         else:
             prev_x = self.time_shift(x)

-        """
-        ### $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
-        """
+        # $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
         receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
         receptance = self.receptance_proj(receptance)

-        """
-        ### $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
-        """
+        # $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
         key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
         key = self.key_proj(key)

-        """
-        ### $v_t=W_v \cdot (\mu_v x_t + (1-\mu_v)x_{t-1})$
-        """
+        # $v_t=W_v \cdot (\mu_v x_t + (1-\mu_v)x_{t-1})$
         value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
         value = self.value_proj(value)

-        """
-        ## WKV calculation
-        """
+        # WKV calculation
         _, seq_length, _ = key.size()
         output = torch.zeros_like(key)

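The channel-mixing formulas annotated above reduce to a token-shifted, gated MLP. The following self-contained sketch runs the same arithmetic with placeholder tensors; the random `W_*` and `mu_*` values stand in for the learned `key_proj`, `value_proj`, `receptance_proj` and `time_mix_*` parameters and are not the repository's initialisation.

```python
import torch

n_embd, hidden = 8, 32                      # hidden plays the role of intermediate_size (4 * n_embd by default)
W_k = torch.randn(hidden, n_embd)           # stands in for key_proj
W_v = torch.randn(n_embd, hidden)           # stands in for value_proj
W_r = torch.randn(n_embd, n_embd)           # stands in for receptance_proj
mu_k = torch.rand(n_embd)                   # time_mix_key
mu_r = torch.rand(n_embd)                   # time_mix_receptance

x_t = torch.randn(n_embd)                   # current token
x_prev = torch.randn(n_embd)                # token-shifted previous input

r_t = W_r @ (mu_r * x_t + (1 - mu_r) * x_prev)   # r_t = W_r · (mu_r x_t + (1 - mu_r) x_{t-1})
k_t = W_k @ (mu_k * x_t + (1 - mu_k) * x_prev)   # k_t = W_k · (mu_k x_t + (1 - mu_k) x_{t-1})
v_t = W_v @ torch.relu(k_t) ** 2                 # v_t = W_v · max(k_t, 0)^2
o_t = torch.sigmoid(r_t) * v_t                   # o_t = sigma(r_t) ⊙ v_t
```
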
@@ -191,9 +178,9 @@ class TimeMixing(Module):
             den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
             max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
         else:
-            num_state = state[self.layer_id,:,NUM_STATE,:]
-            den_state = state[self.layer_id,:,DEN_STATE,:]
-            max_state = state[self.layer_id,:,MAX_STATE,:]
+            num_state = state[self.layer_id, :, NUM_STATE, :]
+            den_state = state[self.layer_id, :, DEN_STATE, :]
+            max_state = state[self.layer_id, :, MAX_STATE, :]

         time_decay = -torch.exp(self.time_decay)

@@ -201,9 +188,7 @@ class TimeMixing(Module):
             current_key = key[:, current_index].float()
             current_value = value[:, current_index]

-            """
-            ### $wkv_t=\frac{\sum^{t-1}_{i=1}e^{-(t-1-i)w+k_i}v_i+e^{u+k_t}v_t}{\sum^{t-1}_{i=1}e^{-(t-1-i)w+k_i}+e^{u+k_t}}$
-            """
+            # $wkv_t=\frac{\sum^{t-1}_{i=1}e^{-(t-1-i)w+k_i}v_i+e^{u+k_t}v_t}{\sum^{t-1}_{i=1}e^{-(t-1-i)w+k_i}+e^{u+k_t}}$
             max_for_output = torch.maximum(max_state, current_key + self.time_first)
             e1 = torch.exp(max_state - max_for_output)
             e2 = torch.exp(current_key + self.time_first - max_for_output)

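The loop body above evaluates the $wkv_t$ formula with a running maximum (`max_state`), the familiar log-sum-exp trick that keeps the exponentials finite. Below is a stand-alone sketch of one step of that recurrence with the same variable names; it is an illustration of the arithmetic only, not a drop-in replacement for the module's implementation.

```python
import torch

def wkv_step(k_t, v_t, num_state, den_state, max_state, time_first, time_decay):
    """One step of the numerically stable WKV recurrence (illustrative sketch)."""
    # Output for the current position, guarded by the running maximum.
    max_for_output = torch.maximum(max_state, k_t + time_first)
    e1 = torch.exp(max_state - max_for_output)
    e2 = torch.exp(k_t + time_first - max_for_output)
    wkv_t = (e1 * num_state + e2 * v_t) / (e1 * den_state + e2)

    # Decay the accumulated numerator/denominator and fold in the current token.
    max_for_state = torch.maximum(max_state + time_decay, k_t)
    e1 = torch.exp(max_state + time_decay - max_for_state)
    e2 = torch.exp(k_t - max_for_state)
    num_state = e1 * num_state + e2 * v_t
    den_state = e1 * den_state + e2
    return wkv_t, num_state, den_state, max_for_state
```
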
@@ -218,110 +203,100 @@ class TimeMixing(Module):
             num_state = e1 * num_state + e2 * current_value
             den_state = e1 * den_state + e2
             max_state = max_for_state

-        """
-        ### update states
-        """
-        state[self.layer_id,:,NUM_STATE,:] = num_state
-        state[self.layer_id,:,DEN_STATE,:] = den_state
-        state[self.layer_id,:,MAX_STATE,:] = max_state
-        wkv, state = self.wkv_function(key,value,use_customized_cuda_kernel=self.config.use_customized_cuda_kernel,state=state)
+        # update states
+        state[self.layer_id, :, NUM_STATE, :] = num_state
+        state[self.layer_id, :, DEN_STATE, :] = den_state
+        state[self.layer_id, :, MAX_STATE, :] = max_state
+        wkv, state = self.wkv_function(key, value, use_customized_cuda_kernel=self.config.use_customized_cuda_kernel,
+                                       state=state)

-        """
-        ### $o_t=W_o \cdot (\sigma(r_t) \odot wkv_t)$
-        """
+        # $o_t=W_o \cdot (\sigma(r_t) \odot wkv_t)$
         rwkv = F.sigmoid(receptance) * wkv
         rwkv = self.output_proj(rwkv)

         return rwkv, state

-"""
-## RWKV block element
-"""
-class Block(Module):

-    def __init__(self, config,layer_id):
+class Block(Module):
+    """
+    ## RWKV block element
+    """
+
+    def __init__(self, config, layer_id):
         super().__init__()
         self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
-        self.attn = TimeMixing(config,layer_id)
+        self.attn = TimeMixing(config, layer_id)
         self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
-        self.ffn = ChannelMixing(config,layer_id)
+        self.ffn = ChannelMixing(config, layer_id)

-    def forward(self, x, state = None):
+    def forward(self, x, state=None):
         # state: [batch_size, 5 , n_embd]
-        """
-        ## time mixing
-        """
+        # time mixing
         residual = x
-        x,state = self.attn(self.ln_1(x),state=state)
+        x, state = self.attn(self.ln_1(x), state=state)
         x = x + residual
-        """
-        ## channel mixing
-        """
+        # channel mixing
         residual = x
-        x, state = self.ffn(self.ln_2(x),state=state)
+        x, state = self.ffn(self.ln_2(x), state=state)
         x = x + residual
         return x, state


 class RWKV(Module):
-    def __init__(self, config,lr_init=0.0008):
+    """
+    ## RWKV
+    """
+    def __init__(self, config, lr_init=0.0008):
         super().__init__()
         assert config.vocab_size is not None
         assert config.block_size is not None
         self.config = config
         self.lr_init = lr_init ## used to initialize embedding parameters
         self.n_layer = config.n_layer
         self.n_embd = config.n_embd
-        """
-        ## Initiate model layers
-        """
-        self.rwkv = nn.ModuleDict(dict(
-            wte = nn.Embedding(config.vocab_size, config.n_embd),
-            ln_p = LayerNorm(config.n_embd, bias=config.bias),
-            h = nn.ModuleList([Block(config,layer_id) for layer_id in range(config.n_layer)]),
-            ln_f = LayerNorm(config.n_embd, bias=config.bias),
-        ))
-        """
-        ## Output linear layer
-        """
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+        # Initiate model layers
+        self.rwkv = nn.ModuleDict(dict(
+            wte=nn.Embedding(config.vocab_size, config.n_embd),
+            ln_p=LayerNorm(config.n_embd, bias=config.bias),
+            h=nn.ModuleList([Block(config, layer_id) for layer_id in range(config.n_layer)]),
+            ln_f=LayerNorm(config.n_embd, bias=config.bias),
+        ))
+
+        # Output linear layer
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

     def forward(self, idx, targets=None, state=None, return_state=False):
         b, t = idx.size()
         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

-        """
-        ## Embedding Layer
-        """
+        # Embedding Layer
         x = self.rwkv.wte(idx)
-        """
-        ## Layer Norm
-        """
+        # Layer Norm
         x = self.rwkv.ln_p(x)
-        """
-        ## RWKV Blocks
-        """
-        for block_idx,block in enumerate(self.rwkv.h):
-            x, state = block(x,state)
+        # RWKV Blocks
+        for block_idx, block in enumerate(self.rwkv.h):
+            x, state = block(x, state)
         x = self.rwkv.ln_f(x)

-        """
-        ## Logit Layer and loss Function (for training)
-        """
+        # Logit Layer and loss Function (for training)
         if targets is not None:
             # if we are given some desired targets also calculate the loss
             logits = self.lm_head(x)
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
             if self.training:
-                loss = L2Wrap.apply(loss,logits)
+                loss = L2Wrap.apply(loss, logits)
         else:
             # inference-time mini-optimization: only forward the lm_head on the very last position
             logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
             loss = None
-        """
-        ## Return Logits and loss
-        """
+        # Return Logits and loss
         if return_state:
             return logits, loss, state
         else:

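`Block` wires the two mixers together with pre-LayerNorm residual connections, and `RWKV` stacks those blocks between an embedding layer and a linear head. The toy sketch below only shows that skeleton: the `attn` and `ffn` placeholders are plain `nn.Linear` layers standing in for `TimeMixing` and `ChannelMixing`, and the sizes are arbitrary, so this is an illustration of the data flow rather than the repository's model.

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Residual skeleton of the Block above with placeholder mixers."""
    def __init__(self, n_embd):
        super().__init__()
        self.ln_1, self.ln_2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)
        self.attn = nn.Linear(n_embd, n_embd)   # placeholder for TimeMixing
        self.ffn = nn.Linear(n_embd, n_embd)    # placeholder for ChannelMixing

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))         # pre-norm residual around time mixing
        x = x + self.ffn(self.ln_2(x))          # pre-norm residual around channel mixing
        return x

n_embd, vocab_size, n_layer = 64, 1000, 2
wte = nn.Embedding(vocab_size, n_embd)          # RWKV.rwkv.wte
ln_p, ln_f = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)
blocks = nn.ModuleList([ToyBlock(n_embd) for _ in range(n_layer)])
lm_head = nn.Linear(n_embd, vocab_size, bias=False)

idx = torch.randint(0, vocab_size, (1, 16))     # (batch, time)
x = ln_p(wte(idx))
for block in blocks:
    x = block(x)
logits = lm_head(ln_f(x))                       # (1, 16, vocab_size)
```
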
@@ -1,15 +1,8 @@
-import copy
-
-import torch.nn as nn
-
-from labml.configs import BaseConfigs, option, calculate, aggregate
-from labml_helpers.module import Module
+from labml.configs import BaseConfigs


 class RWKVConfigs(BaseConfigs):
     """
-    <a id="TransformerConfigs"></a>
-
     ## Transformer Configurations

     This defines configurations for a transformer.

@@ -1,22 +1,16 @@
-import math,time
-import os
 import inspect
-from dataclasses import dataclass
+import math

 import torch
 import torch.nn as nn
-from torch.nn import functional as F
-import numpy as np
+from labml_nn.RWKV.configs import RWKVConfigs

+from __init__ import RWKV
+from __init__ import TimeMixing
 from labml import experiment
 from labml.configs import option
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
-from labml_nn.optimizers.configs import OptimizerConfigs
-from configs import RWKVConfigs
-from __init__ import RWKV
-from __init__ import TimeMixing
-from __init__ import ChannelMixing
-from contextlib import nullcontext

 class Configs(NLPAutoRegressionConfigs):
     """

@@ -33,7 +27,7 @@ class Configs(NLPAutoRegressionConfigs):
     # number of warmup iterations
     warmup_iters: int = 2000
     # total number of training iterations
     max_iters: int = 600000
     # weight decay
     weight_decay: float = 1e-1
     # Custom optimizer

@@ -41,6 +35,7 @@ class Configs(NLPAutoRegressionConfigs):
     beta2: float = 0.95
     optimizer = 'rwkv_optimizer'

+
 @option(Configs.rwkv, 'RWKV')
 def _rwkv_configs(c: Configs):
     """

@@ -58,14 +53,13 @@ def _rwkv_configs(c: Configs):


 def _init_weights(module, rwkv: RWKVConfigs):
-    ## initialize Vector Parameters in TimeMixing
-    if isinstance(module,TimeMixing):
+    # initialize Vector Parameters in TimeMixing
+    if isinstance(module, TimeMixing):
         layer_id = module.layer_id
         n_layer = module.n_layer
         n_embd = module.n_embd
         attn_sz = n_embd

         with torch.no_grad():
             ratio_0_to_1 = layer_id / (n_layer - 1) # 0 to 1
             ratio_1_to_almost0 = 1.0 - (layer_id / n_layer) # 1 to ~0

@@ -83,7 +77,6 @@ def _init_weights(module, rwkv: RWKVConfigs):
             module.time_mix_key = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
             module.time_mix_value = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
             module.time_mix_receptance = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))


-
 @option(Configs.model)

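The initialisation above depends on the layer position: `ratio_1_to_almost0` shrinks from 1 towards 0 as the layer index grows, so deeper layers mix in more of the shifted token. The hunk does not show how `ddd` is built; the official RWKV initialisation uses a per-channel ramp `i / n_embd`, and the sketch below assumes that — treat it as an illustration, not the repository's exact code.

```python
import torch

def time_mix_init(layer_id: int, n_layer: int, n_embd: int):
    """Sketch of the layer-dependent time-mix initialisation shown above."""
    ratio_0_to_1 = layer_id / (n_layer - 1)           # 0 for the first layer, 1 for the last
    ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)   # 1 for the first layer, ~0 for the last

    # Assumed per-channel ramp standing in for the elided `ddd`.
    ddd = torch.arange(n_embd, dtype=torch.float32).reshape(1, 1, n_embd) / n_embd

    time_mix_key = torch.pow(ddd, ratio_1_to_almost0)
    time_mix_value = torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
    time_mix_receptance = torch.pow(ddd, 0.5 * ratio_1_to_almost0)
    return time_mix_key, time_mix_value, time_mix_receptance
```
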
@@ -101,30 +94,30 @@ def _model(c: Configs):

 @option(NLPAutoRegressionConfigs.optimizer)
 def _configure_optimizers(c: NLPAutoRegressionConfigs):
     # start with all of the candidate parameters
     param_dict = {pn: p for pn, p in c.model.named_parameters()}
     # filter out those that do not require grad
     param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
     # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
     # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
     decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
     nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
     optim_groups = [
         {'params': decay_params, 'weight_decay': c.weight_decay},
         {'params': nodecay_params, 'weight_decay': 0.0}
     ]
     num_decay_params = sum(p.numel() for p in decay_params)
     num_nodecay_params = sum(p.numel() for p in nodecay_params)
     print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
     print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
     # Create AdamW optimizer and use the fused version if it is available
     fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
     use_fused = fused_available and c.device_type == 'cuda'
     extra_args = dict(fused=True) if use_fused else dict()
     optimizer = torch.optim.AdamW(optim_groups, lr=c.learning_rate, betas=c.betas, **extra_args)
     print(f"using fused AdamW: {use_fused}")

     return optimizer


 def main():

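The optimizer option above splits parameters by dimensionality: matrices and embeddings (2-D and higher) get weight decay, while biases and normalisation parameters (1-D) do not. A small self-contained illustration of that grouping rule on a toy model follows; the learning rate and betas here are placeholders, whereas the experiment reads them from its config.

```python
import torch
import torch.nn as nn

# Toy model, only to show which tensors land in which group.
toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

param_dict = {pn: p for pn, p in toy.named_parameters() if p.requires_grad}
decay_params = [p for p in param_dict.values() if p.dim() >= 2]     # Linear weight
nodecay_params = [p for p in param_dict.values() if p.dim() < 2]    # biases, LayerNorm weight

optimizer = torch.optim.AdamW(
    [{'params': decay_params, 'weight_decay': 0.1},
     {'params': nodecay_params, 'weight_decay': 0.0}],
    lr=3e-4, betas=(0.9, 0.95))
```
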
@@ -154,11 +147,11 @@ def main():
         # per epoch
         'inner_iterations': 10,

-        'rwkv.block_size' : 1024,
+        'rwkv.block_size': 1024,
         # model
-        'rwkv.n_layer' : 12,
-        'rwkv.n_heads' : 12,
-        'rwkv.n_embd' : 768
+        'rwkv.n_layer': 12,
+        'rwkv.n_heads': 12,
+        'rwkv.n_embd': 768
     })

     print(conf.model)