mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 08:41:23 +08:00

RWKV docs
@@ -124,9 +124,7 @@
<h4>✨ <a href="conv_mixer/index.html">ConvMixer</a></h4>
<h4>✨ <a href="capsule_networks/index.html">Capsule Networks</a></h4>
<h4>✨ <a href="unet/index.html">U-Net</a></h4>
<h4>✨ <a href="sketch_rnn/index.html">RNNs</a></h4>
<ul><li><a href="rwkv/index.html">RWKV</a> </li>
<li><a href="sketch_rnn/index.html">Sketch RNN</a></li></ul>
<h4>✨ <a href="sketch_rnn/index.html">Sketch RNN</a></h4>
<h4>✨ Graph Neural Networks</h4>
<ul><li><a href="graphs/gat/index.html">Graph Attention Networks (GAT)</a> </li>
<li><a href="graphs/gatv2/index.html">Graph Attention Networks v2 (GATv2)</a></li></ul>

@@ -170,7 +168,6 @@
<ul><li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf">Autoregressive Search Engines: Generating Substrings as Document Identifiers</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf">Training Compute-Optimal Large Language Models</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1910.02054.pdf">ZeRO: Memory Optimizations Toward Training Trillion Parameter Models</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/RWKV.pdf">RWKV: Reinventing RNNs for the Transformer Era</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.02311.pdf">PaLM: Scaling Language Modeling with Pathways</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/dall-e-2.pdf">Hierarchical Text-Conditional Image Generation with CLIP Latents</a> </li>
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.14465.pdf">STaR: Self-Taught Reasoner Bootstrapping Reasoning With Reasoning</a> </li>
@@ -503,6 +503,27 @@
</url>

<url>
<loc>https://nn.labml.ai/RWKV/configs.html</loc>
<lastmod>2024-03-17T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>

<url>
<loc>https://nn.labml.ai/RWKV/index.html</loc>
<lastmod>2024-03-17T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>

<url>
<loc>https://nn.labml.ai/RWKV/experiment.html</loc>
<lastmod>2024-03-17T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>

<url>
<loc>https://nn.labml.ai/cfr/infoset_saver.html</loc>
<lastmod>2021-06-21T16:30:00+00:00</lastmod>
@@ -9,43 +9,34 @@ summary: >

# Receptance Weighted Key Value (RWKV)

## TODO: make a Colab notebook?

This is a tutorial/implementation of RWKV
from the paper [RWKV: Reinventing RNNs for the Transformer Era](https://arxiv.org/pdf/2305.13048.pdf)
in [PyTorch](https://pytorch.org/).

Full definition of an RWKV language model, all of it in this single file.
References:
1) the official RWKV PyTorch implementation released by Bo Peng:
https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
1) [the official RWKV PyTorch implementation released by Bo Peng](https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py)
2) [huggingface/transformers PyTorch implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py)
"""
import math,time
import os
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

from labml_helpers.module import Module


PREV_X_TIME = 0
NUM_STATE = 1
DEN_STATE = 2
MAX_STATE = 3
PREV_X_CHANNEL = 4
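# Annotation (not part of the diff): TimeMixing and ChannelMixing below index the recurrent
# state as state[layer_id, :, INDEX, :], so the full state is assumed to be a tensor of
# shape [n_layer, batch_size, 5, n_embd], one slot per constant above. A minimal sketch of a
# hypothetical helper that allocates it for single-token decoding:
def init_state(n_layer: int, batch_size: int, n_embd: int) -> torch.Tensor:
    state = torch.zeros(n_layer, batch_size, 5, n_embd)
    # the running maximum used by the WKV max-trick starts very negative,
    # matching the -1e38 initialisation inside TimeMixing.forward
    state[:, :, MAX_STATE, :] = -1e38
    return state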

"""
## Layernorm with bias
"""

class LayerNorm(Module):
"""
### Layer normalization with bias
"""

def __init__(self, ndim, bias):
super().__init__()
self.weight = nn.Parameter(torch.ones(ndim))
@@ -54,15 +45,19 @@ class LayerNorm(Module):
def forward(self, input):
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
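# Annotation (not part of the diff): a minimal usage sketch of the LayerNorm above; `bias`
# only controls whether a learnable offset is added on top of the learnable scale:
#
#     ln = LayerNorm(ndim=768, bias=True)
#     y = ln(torch.randn(2, 16, 768))   # normalised over the last dimension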

"""
# L2 loss wrapper
https://github.com/BlinkDL/RWKV-LM/blob/cca1b5e8e597cf40675882bb10b46287c844e35c/RWKV-v4/src/model.py#L21
"""

class L2Wrap(torch.autograd.Function):
"""
### L2 loss wrapper

[ref](https://github.com/BlinkDL/RWKV-LM/blob/cca1b5e8e597cf40675882bb10b46287c844e35c/RWKV-v4/src/model.py#L21)
"""

@staticmethod
def forward(ctx, loss, y):
ctx.save_for_backward(y)
return loss

@staticmethod
def backward(ctx, grad_output):
y = ctx.saved_tensors[0]
@@ -71,13 +66,15 @@ class L2Wrap(torch.autograd.Function):
maxx, ids = torch.max(y, -1, keepdim=True)
gy = torch.zeros_like(y)
gy.scatter_(-1, ids, maxx * factor)
return (grad_output, gy)
return grad_output, gy
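# Annotation (not part of the diff): during training the language-model loss is routed
# through this wrapper (see RWKV.forward below). The loss value itself is unchanged; the
# backward pass just adds a small extra gradient on the largest logit of every position,
# and the stated intent in the referenced implementation is to keep logits from growing
# too large:
#
#     loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
#     loss = L2Wrap.apply(loss, logits)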


class ChannelMixing(Module):
"""
## Channel Mixing
### Channel Mixing
"""

def __init__(self,config,layer_id):
def __init__(self, config, layer_id):
super().__init__()
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
# token shifting
@@ -88,51 +85,47 @@ class ChannelMixing(Module):
config.intermediate_size if config.intermediate_size is not None else 4 * n_embd
)

## Learnable Matrix
self.key_proj = nn.Linear(n_embd,intermediate_size,bias=False)
self.value_proj = nn.Linear(intermediate_size,n_embd,bias=False)
self.receptance_proj = nn.Linear(n_embd,n_embd,bias=False)
# Learnable Matrix
self.key_proj = nn.Linear(n_embd, intermediate_size, bias=False)
self.value_proj = nn.Linear(intermediate_size, n_embd, bias=False)
self.receptance_proj = nn.Linear(n_embd, n_embd, bias=False)

## Learnable Vector
self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
# Learnable Vector
self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

def forward(self,x,state=None):
def forward(self, x, state=None):
"""
# x = (Batch,Time,Channel)
"""
if state is not None:
prev_x = state[self.layer_id,:,[PREV_X_CHANNEL],:]
state[self.layer_id,:,[PREV_X_CHANNEL],:] = x
prev_x = state[self.layer_id, :, [PREV_X_CHANNEL], :]
state[self.layer_id, :, [PREV_X_CHANNEL], :] = x
else:
prev_x = self.time_shift(x)

"""
### $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
"""
# $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
receptance = self.receptance_proj(receptance)

"""
### $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
"""
# $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
key = self.key_proj(key)

"""
### $v_t=W_v \cdot \max(k_t,0)^2$
"""
# $v_t=W_v \cdot \max(k_t,0)^2$
value = self.value_proj(torch.square(torch.relu(key)))

"""
### $o_t=\sigma(r_t) \odot v_t$
"""
# $o_t=\sigma(r_t) \odot v_t$
out = F.sigmoid(receptance) * value
return out, state
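# Annotation (not part of the diff): a minimal sketch of what the nn.ZeroPad2d((0, 0, 1, -1))
# token shift above does to a (Batch, Time, Channel) tensor -- it prepends one step of zeros
# and drops the last step, so every position t mixes x_t with x_{t-1}:
def _time_shift_demo() -> torch.Tensor:
    x = torch.arange(1., 4.).view(1, 3, 1)    # time steps [1], [2], [3]
    return nn.ZeroPad2d((0, 0, 1, -1))(x)     # time steps [0], [1], [2]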

"""
## Time Mixing
"""

class TimeMixing(Module):
def __init__(self,config,layer_id):
"""
### Time Mixing
"""

def __init__(self, config, layer_id):
super().__init__()
self.config = config
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
@@ -141,48 +134,42 @@ class TimeMixing(Module):
n_embd = config.n_embd
attn_sz = n_embd

## learnable matrix
self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
# learnable matrix
self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)
self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)

## learnable vector
self.time_decay = nn.Parameter(torch.empty(attn_sz))
self.time_first = nn.Parameter(torch.empty(attn_sz))
self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
# learnable vector
self.time_decay = nn.Parameter(torch.empty(attn_sz))
self.time_first = nn.Parameter(torch.empty(attn_sz))
self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

def forward(self,x,state=None):
# x = (Batch,Time,Channel)
def forward(self, x, state=None):
"""
x = (Batch,Time,Channel)
"""
if state is not None:
prev_x = state[self.layer_id,:,[PREV_X_TIME],:]
state[self.layer_id,:,[PREV_X_TIME],:] = x
prev_x = state[self.layer_id, :, [PREV_X_TIME], :]
state[self.layer_id, :, [PREV_X_TIME], :] = x
else:
prev_x = self.time_shift(x)

"""
### $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
"""
# $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
receptance = self.receptance_proj(receptance)

"""
### $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
"""
# $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
key = self.key_proj(key)

"""
### $v_t=W_v \cdot (\mu_v x_t + (1-\mu_v)x_{t-1})$
"""
# $v_t=W_v \cdot (\mu_v x_t + (1-\mu_v)x_{t-1})$
value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
value = self.value_proj(value)

"""
## WKV calculation
"""
# WKV calculation
_, seq_length, _ = key.size()
output = torch.zeros_like(key)

@@ -191,9 +178,9 @@ class TimeMixing(Module):
den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
else:
num_state = state[self.layer_id,:,NUM_STATE,:]
den_state = state[self.layer_id,:,DEN_STATE,:]
max_state = state[self.layer_id,:,MAX_STATE,:]
num_state = state[self.layer_id, :, NUM_STATE, :]
den_state = state[self.layer_id, :, DEN_STATE, :]
max_state = state[self.layer_id, :, MAX_STATE, :]

time_decay = -torch.exp(self.time_decay)
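# Annotation (not part of the diff): self.time_decay is stored in log space, so the -exp(.)
# above makes the effective per-channel decay strictly negative; each step of the recurrence
# below therefore multiplies the stored numerator/denominator by a factor in (0, 1), which is
# how older tokens fade out of the weighted average.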

@@ -201,9 +188,7 @@
current_key = key[:, current_index].float()
current_value = value[:, current_index]

"""
### $wkv_t=\frac{\sum^{t-1}_{i=1}e^{-(t-1-i)w+k_i}v_i+e^{u+k_t}v_t}{\sum^{t-1}_{i=1}e^{-(t-1-i)w+k_i}+e^{u+k_t}}$
"""
# $wkv_t=\frac{\sum^{t-1}_{i=1}e^{-(t-1-i)w+k_i}v_i+e^{u+k_t}v_t}{\sum^{t-1}_{i=1}e^{-(t-1-i)w+k_i}+e^{u+k_t}}$
max_for_output = torch.maximum(max_state, current_key + self.time_first)
e1 = torch.exp(max_state - max_for_output)
e2 = torch.exp(current_key + self.time_first - max_for_output)
@@ -219,109 +204,99 @@ class TimeMixing(Module):
den_state = e1 * den_state + e2
max_state = max_for_state

"""
### update states
"""
state[self.layer_id,:,NUM_STATE,:] = num_state
state[self.layer_id,:,DEN_STATE,:] = den_state
state[self.layer_id,:,MAX_STATE,:] = max_state
wkv, state = self.wkv_function(key,value,use_customized_cuda_kernel=self.config.use_customized_cuda_kernel,state=state)
# update states
state[self.layer_id, :, NUM_STATE, :] = num_state
state[self.layer_id, :, DEN_STATE, :] = den_state
state[self.layer_id, :, MAX_STATE, :] = max_state
wkv, state = self.wkv_function(key, value, use_customized_cuda_kernel=self.config.use_customized_cuda_kernel,
state=state)

"""
### $o_t=W_o \cdot (\sigma(r_t) \odot wkv_t)$
"""
# $o_t=W_o \cdot (\sigma(r_t) \odot wkv_t)$
rwkv = F.sigmoid(receptance) * wkv
rwkv = self.output_proj(rwkv)

return rwkv, state
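# Annotation (not part of the diff): a self-contained sketch of the numerically stabilised
# WKV recurrence that the loop above implements (pure-PyTorch path, no CUDA kernel). The
# free-standing function and its names are mine, not the diff's:
def wkv_scan_sketch(key: torch.Tensor, value: torch.Tensor,
                    time_first: torch.Tensor, time_decay: torch.Tensor) -> torch.Tensor:
    # key/value: (batch, seq, n_embd); time_first is the per-channel bonus u for the
    # current token, time_decay is the (already negative) per-channel decay w.
    num = torch.zeros_like(key[:, 0])
    den = torch.zeros_like(key[:, 0])
    mx = torch.zeros_like(key[:, 0]) - 1e38
    out = torch.zeros_like(key)
    for t in range(key.size(1)):
        k, v = key[:, t], value[:, t]
        # read-out: include the current token with the u bonus
        m = torch.maximum(mx, k + time_first)
        e1, e2 = torch.exp(mx - m), torch.exp(k + time_first - m)
        out[:, t] = (e1 * num + e2 * v) / (e1 * den + e2)
        # state update: decay the stored history by w, then absorb the current token
        m = torch.maximum(mx + time_decay, k)
        e1, e2 = torch.exp(mx + time_decay - m), torch.exp(k - m)
        num, den, mx = e1 * num + e2 * v, e1 * den + e2, m
    return out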

"""
## RWKV block element
"""
class Block(Module):

def __init__(self, config,layer_id):

class Block(Module):
"""
## RWKV block element
"""

def __init__(self, config, layer_id):
super().__init__()
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
self.attn = TimeMixing(config,layer_id)
self.attn = TimeMixing(config, layer_id)
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
self.ffn = ChannelMixing(config,layer_id)
self.ffn = ChannelMixing(config, layer_id)

def forward(self, x, state = None):
def forward(self, x, state=None):
# state: [n_layer, batch_size, 5, n_embd]
"""
## time mixing
"""

# time mixing
residual = x
x,state = self.attn(self.ln_1(x),state=state)
x, state = self.attn(self.ln_1(x), state=state)
x = x + residual
"""
## channel mixing
"""

# channel mixing
residual = x
x, state = self.ffn(self.ln_2(x),state=state)
x, state = self.ffn(self.ln_2(x), state=state)
x = x + residual
return x, state


class RWKV(Module):
def __init__(self, config,lr_init=0.0008):
"""
## RWKV
"""
def __init__(self, config, lr_init=0.0008):
super().__init__()
assert config.vocab_size is not None
assert config.block_size is not None
self.config = config
self.lr_init = lr_init ## used to initialize embedding parameters
self.lr_init = lr_init ## used to initialize embedding parameters
self.n_layer = config.n_layer
self.n_embd = config.n_embd
"""
## Initiate model layers
"""
self.rwkv = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd),
ln_p = LayerNorm(config.n_embd, bias=config.bias),
h = nn.ModuleList([Block(config,layer_id) for layer_id in range(config.n_layer)]),
ln_f = LayerNorm(config.n_embd, bias=config.bias),
))
"""
## Output linear layer
"""
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

# Initiate model layers
self.rwkv = nn.ModuleDict(dict(
wte=nn.Embedding(config.vocab_size, config.n_embd),
ln_p=LayerNorm(config.n_embd, bias=config.bias),
h=nn.ModuleList([Block(config, layer_id) for layer_id in range(config.n_layer)]),
ln_f=LayerNorm(config.n_embd, bias=config.bias),
))

# Output linear layer
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

def forward(self, idx, targets=None, state=None, return_state=False):
b, t = idx.size()
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

"""
## Embedding Layer
"""
# Embedding Layer
x = self.rwkv.wte(idx)
"""
## Layer Norm
"""

# Layer Norm
x = self.rwkv.ln_p(x)
"""
## RWKV Blocks
"""
for block_idx,block in enumerate(self.rwkv.h):
x, state = block(x,state)

# RWKV Blocks
for block_idx, block in enumerate(self.rwkv.h):
x, state = block(x, state)
x = self.rwkv.ln_f(x)

"""
## Logit Layer and loss Function (for training)
"""
# Logit Layer and loss Function (for training)
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.lm_head(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
if self.training:
loss = L2Wrap.apply(loss,logits)
loss = L2Wrap.apply(loss, logits)
else:
# inference-time mini-optimization: only forward the lm_head on the very last position
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
loss = None
"""
## Return Logits and loss
"""

# Return Logits and loss
if return_state:
return logits, loss, state
else:
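# Annotation (not part of the diff): a minimal training-side sketch, assuming a `config`
# object carrying the fields used above (vocab_size, block_size, n_layer, n_embd, bias, ...).
# The returned loss already has the L2Wrap hook applied when the model is in training mode:
#
#     model = RWKV(config)
#     logits, loss, state = model(idx, targets=targets, return_state=True)
#     loss.backward()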

@@ -1,15 +1,8 @@
import copy

import torch.nn as nn

from labml.configs import BaseConfigs, option, calculate, aggregate
from labml_helpers.module import Module
from labml.configs import BaseConfigs


class RWKVConfigs(BaseConfigs):
"""
|
||||
<a id="TransformerConfigs"></a>
|
||||
|
||||
## Transformer Configurations
|
||||
|
||||
This defines configurations for a transformer.
|
||||
|

@@ -1,22 +1,16 @@
import math,time
import os
import inspect
from dataclasses import dataclass
import math

import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from labml_nn.RWKV.configs import RWKVConfigs

from __init__ import RWKV
from __init__ import TimeMixing
from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.optimizers.configs import OptimizerConfigs
from configs import RWKVConfigs
from __init__ import RWKV
from __init__ import TimeMixing
from __init__ import ChannelMixing
from contextlib import nullcontext


class Configs(NLPAutoRegressionConfigs):
"""
@@ -41,6 +35,7 @@ class Configs(NLPAutoRegressionConfigs):
beta2: float = 0.95
optimizer = 'rwkv_optimizer'


@option(Configs.rwkv, 'RWKV')
def _rwkv_configs(c: Configs):
"""
@@ -58,9 +53,8 @@ def _rwkv_configs(c: Configs):


def _init_weights(module, rwkv: RWKVConfigs):

## initialize Vector Parameters in TimeMixing
if isinstance(module,TimeMixing):
# initialize Vector Parameters in TimeMixing
if isinstance(module, TimeMixing):
layer_id = module.layer_id
n_layer = module.n_layer
n_embd = module.n_embd
@@ -85,7 +79,6 @@ def _init_weights(module, rwkv: RWKVConfigs):
module.time_mix_receptance = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))


@option(Configs.model)
def _model(c: Configs):
"""
@@ -101,30 +94,30 @@ def _model(c: Configs):

@option(NLPAutoRegressionConfigs.optimizer)
def _configure_optimizers(c: NLPAutoRegressionConfigs):
# start with all of the candidate parameters
param_dict = {pn: p for pn, p in c.model.named_parameters()}
# filter out those that do not require grad
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
# create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
optim_groups = [
{'params': decay_params, 'weight_decay': c.weight_decay},
{'params': nodecay_params, 'weight_decay': 0.0}
]
num_decay_params = sum(p.numel() for p in decay_params)
num_nodecay_params = sum(p.numel() for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
# Create AdamW optimizer and use the fused version if it is available
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and c.device_type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=c.learning_rate, betas=c.betas, **extra_args)
print(f"using fused AdamW: {use_fused}")

return optimizer
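# Annotation (not part of the diff): with the p.dim() >= 2 rule above, the Linear and
# Embedding weights (2-D) and the (1, 1, n_embd) time_mix_* vectors (3-D) fall into the
# weight-decay group, while the 1-D time_decay, time_first and LayerNorm parameters do not.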


def main():
@@ -154,11 +147,11 @@ def main():
# per epoch
'inner_iterations': 10,

'rwkv.block_size' : 1024,
'rwkv.block_size': 1024,
# model
'rwkv.n_layer' : 12,
'rwkv.n_heads' : 12,
'rwkv.n_embd' : 768
'rwkv.n_layer': 12,
'rwkv.n_heads': 12,
'rwkv.n_embd': 768
})

print(conf.model)