This is a tutorial implementation, in PyTorch, of RWKV from the paper "RWKV: Reinventing RNNs for the Transformer Era".
Full definition of an RWKV language model, all of it in this single file. References: 1) the official RWKV PyTorch implementation released by Bo Peng, 2) the huggingface/transformers PyTorch implementation.
import torch
import torch.nn as nn
from torch.nn import functional as F


PREV_X_TIME = 0
NUM_STATE = 1
DEN_STATE = 2
MAX_STATE = 3
PREV_X_CHANNEL = 4
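These five constants index the per-layer slots of the recurrent state: the previous input to time mixing, the numerator, denominator and running max of the WKV accumulator, and the previous input to channel mixing. The classes below index the state as state[layer_id, batch, slot, channel], so a state tensor could be allocated as in this sketch (the helper name and the explicit -1e38 fill are assumptions, mirroring the state=None defaults used in TimeMixing):

def make_empty_state(n_layer, batch_size, n_embd, device="cpu"):
    # One row of shape (n_embd,) per slot defined above, for every layer and batch element.
    state = torch.zeros(n_layer, batch_size, 5, n_embd, device=device)
    # The running max starts at a very large negative number, like the state=None branch of TimeMixing.
    state[:, :, MAX_STATE, :] = -1e38
    return state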
Layer normalization with an optional bias term.
class LayerNorm(nn.Module):
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
L2Wrap is a small autograd wrapper used during training: the forward pass returns the loss unchanged, while the backward pass adds a tiny extra gradient that nudges each position's largest logit toward zero, mildly regularizing the logit magnitudes.
class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
To encourage the logits to stay close to 0:
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return grad_output, gy
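A quick way to see what L2Wrap does: the returned value is numerically identical to the loss, but the backward pass adds a sparse gradient on the logits that is non-zero only at each position's argmax and proportional to that maximum logit. A small sketch (the logsumexp loss is just a stand-in for the real language-model loss):

logits = torch.randn(2, 8, 100, requires_grad=True)  # (batch, time, vocab)
loss = logits.logsumexp(dim=-1).mean()               # stand-in for the real LM loss
L2Wrap.apply(loss, logits).backward()
# logits.grad now holds the usual loss gradient plus 1e-4 / (batch * time) times the
# max logit, scattered onto the argmax entry of each (batch, time) position.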
Channel mixing: the feed-forward sub-layer (ffn) of each block.
class ChannelMixing(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
Token shifting: ZeroPad2d((0, 0, 1, -1)) shifts the sequence one step along the time dimension, so each position can mix its own input with the previous token's input.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.layer_id = layer_id

        n_embd = config.n_embd
        intermediate_size = (
            config.intermediate_size if config.intermediate_size is not None else 4 * n_embd
        )
Learnable projection matrices
        self.key_proj = nn.Linear(n_embd, intermediate_size, bias=False)
        self.value_proj = nn.Linear(intermediate_size, n_embd, bias=False)
        self.receptance_proj = nn.Linear(n_embd, n_embd, bias=False)
Learnable token-shift mixing vectors (created with torch.empty, so they must be initialized before use)
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))
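The ZeroPad2d((0, 0, 1, -1)) trick is worth seeing in isolation: on a (batch, time, channel) tensor it pads one zero step at the start of the time axis and crops the last step, so every position receives the features of the token before it. A tiny standalone example with made-up values:

shift = nn.ZeroPad2d((0, 0, 1, -1))
x = torch.arange(1.0, 5.0).view(1, 4, 1)   # a batch of one sequence: tokens 1, 2, 3, 4
print(shift(x).flatten().tolist())          # [0.0, 1.0, 2.0, 3.0] -- each token sees its predecessor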
    def forward(self, x, state=None):
During recurrent inference the previous token's input comes from the state; during parallel (training) mode the whole sequence is token-shifted at once.
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_CHANNEL], :]
            state[self.layer_id, :, [PREV_X_CHANNEL], :] = x
        else:
            prev_x = self.time_shift(x)
Receptance: acts as a gate on the value
        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)
Key
        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)
Value: squared-ReLU feed-forward applied to the key
        value = self.value_proj(torch.square(torch.relu(key)))
Gate the value with the sigmoid of the receptance
        out = torch.sigmoid(receptance) * value
        return out, state
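Written per token, channel mixing is a gated squared-ReLU feed-forward network that blends the current and previous inputs before each projection. The helper below is only a restatement for clarity (channel_mix_step is a hypothetical name, not part of the class); it operates on one token x_t and the previous token prev_x:

def channel_mix_step(mix, x_t, prev_x):
    # r_t = W_r (mu_r * x_t + (1 - mu_r) * x_{t-1})
    r = mix.receptance_proj(x_t * mix.time_mix_receptance + prev_x * (1 - mix.time_mix_receptance))
    # k_t = W_k (mu_k * x_t + (1 - mu_k) * x_{t-1})
    k = mix.key_proj(x_t * mix.time_mix_key + prev_x * (1 - mix.time_mix_key))
    # out_t = sigmoid(r_t) * W_v relu(k_t)^2
    return torch.sigmoid(r) * mix.value_proj(torch.square(torch.relu(k)))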
Time mixing: the attention-like sub-layer (attn) of each block; self-attention is replaced by the WKV recurrence below.
class TimeMixing(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.layer_id = layer_id

        n_embd = config.n_embd
        attn_sz = n_embd
Learnable projection matrices
        self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)
Learnable vectors: per-channel decay, a per-channel bonus for the current token, and the token-shift mixing vectors (all created with torch.empty, so they must be initialized before use)
        self.time_decay = nn.Parameter(torch.empty(attn_sz))
        self.time_first = nn.Parameter(torch.empty(attn_sz))
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))
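All of the nn.Parameter tensors above (and the two mixing vectors in ChannelMixing) hold arbitrary memory until they are explicitly initialized. The official RWKV code fills them with layer-dependent schedules; the sketch below only shows the shape of such an initialization, with simple placeholder constants that are assumptions rather than the official values:

def init_time_mixing_placeholder(tm):
    # Placeholder initialization so the parameters are well-defined; not the official RWKV scheme.
    with torch.no_grad():
        tm.time_decay.fill_(-1.0)            # log-scale decay; the model applies -exp(time_decay) per channel
        tm.time_first.fill_(0.0)             # bonus added to the current token's key
        tm.time_mix_key.fill_(0.5)           # blend current and previous token 50/50
        tm.time_mix_value.fill_(0.5)
        tm.time_mix_receptance.fill_(0.5)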
The input x has shape (batch, time, channel).
    def forward(self, x, state=None):
Token shift: the previous input comes from the state (recurrent mode) or from time_shift (parallel mode).
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_TIME], :]
            state[self.layer_id, :, [PREV_X_TIME], :] = x
        else:
            prev_x = self.time_shift(x)
Receptance
        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)
Key
        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)
Value
        value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
        value = self.value_proj(value)
WKV calculation: a numerically stable evaluation of the linear-attention recurrence.
        _, seq_length, _ = key.size()
        output = torch.zeros_like(key)

        if state is None:
            num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
        else:
            num_state = state[self.layer_id, :, NUM_STATE, :]
            den_state = state[self.layer_id, :, DEN_STATE, :]
            max_state = state[self.layer_id, :, MAX_STATE, :]

        time_decay = -torch.exp(self.time_decay)

        for current_index in range(seq_length):
            current_key = key[:, current_index].float()
            current_value = value[:, current_index]
Output for the current position: the running average plus the current token boosted by time_first.
            max_for_output = torch.maximum(max_state, current_key + self.time_first)
            e1 = torch.exp(max_state - max_for_output)
            e2 = torch.exp(current_key + self.time_first - max_for_output)
            numerator = e1 * num_state + e2 * current_value
            denominator = e1 * den_state + e2
            output[:, current_index] = (numerator / denominator).to(output.dtype)
Update the numerator/denominator/max accumulators for the next iteration.
            max_for_state = torch.maximum(max_state + time_decay, current_key)
            e1 = torch.exp(max_state + time_decay - max_for_state)
            e2 = torch.exp(current_key - max_for_state)
            num_state = e1 * num_state + e2 * current_value
            den_state = e1 * den_state + e2
            max_state = max_for_state
Write the accumulators back into the state (recurrent inference only).
        if state is not None:
            state[self.layer_id, :, NUM_STATE, :] = num_state
            state[self.layer_id, :, DEN_STATE, :] = den_state
            state[self.layer_id, :, MAX_STATE, :] = max_state
In practice this Python loop can be replaced by a customized CUDA kernel (see config.use_customized_cuda_kernel); the loop's result is the WKV output.
        wkv = output
Gate WKV with the sigmoid of the receptance, then project back to the embedding size.
        rwkv = torch.sigmoid(receptance) * wkv
        rwkv = self.output_proj(rwkv)

        return rwkv, state
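The loop above is a numerically stabilized evaluation of the WKV recurrence: each output is a weighted average of past values, where the weight of a value decays exponentially with its age through time_decay, and the current token receives an extra time_first bonus. The naive quadratic-time reference below spells out that formula directly (a sketch for intuition only; without the running-max trick it overflows easily, so it only matches the loop for small, well-behaved inputs):

def naive_wkv(key, value, time_decay, time_first):
    # key, value: (batch, time, channels); time_decay, time_first: (channels,)
    # time_decay is the raw parameter, exactly as stored on the module.
    w = torch.exp(time_decay)   # positive per-channel decay rate
    _, T, _ = key.shape
    out = torch.zeros_like(value)
    for t in range(T):
        # current token, boosted by the time_first bonus
        num = torch.exp(time_first + key[:, t]) * value[:, t]
        den = torch.exp(time_first + key[:, t])
        # past tokens, down-weighted by how long ago they occurred
        for i in range(t):
            weight = torch.exp(-(t - 1 - i) * w + key[:, i])
            num = num + weight * value[:, i]
            den = den + weight
        out[:, t] = num / den
    return out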
A residual RWKV block: time mixing then channel mixing, each wrapped in a pre-LayerNorm residual connection.
class Block(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = TimeMixing(config, layer_id)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.ffn = ChannelMixing(config, layer_id)

    def forward(self, x, state=None):
state has shape (n_layer, batch_size, 5, n_embd); each block reads and writes only its own layer_id slice.
Time mixing with a residual connection
        residual = x
        x, state = self.attn(self.ln_1(x), state=state)
        x = x + residual
Channel mixing with a residual connection
        residual = x
        x, state = self.ffn(self.ln_2(x), state=state)
        x = x + residual
        return x, state
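A property worth knowing about this block (and the full model) is that the two execution modes agree: running a whole sequence with state=None gives the same outputs as feeding the tokens one at a time while carrying a state tensor. The sketch below checks this for a single block; it assumes the block's parameters have already been initialized (they are created with torch.empty), and the helper name is hypothetical:

@torch.no_grad()
def check_block_consistency(block, n_layer, n_embd, T=6):
    x = torch.randn(1, T, n_embd)
    y_parallel, _ = block(x, state=None)                 # whole sequence at once
    state = torch.zeros(n_layer, 1, 5, n_embd)
    state[:, :, MAX_STATE, :] = -1e38
    ys = []
    for t in range(T):                                   # one token at a time
        y_t, state = block(x[:, t:t + 1, :], state=state)
        ys.append(y_t)
    y_recurrent = torch.cat(ys, dim=1)
    return torch.allclose(y_parallel, y_recurrent, atol=1e-4)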
The full RWKV language model: token embedding, an input LayerNorm, a stack of blocks, a final LayerNorm, and the language-model head.
class RWKV(nn.Module):
    def __init__(self, config, lr_init=0.0008):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.lr_init = lr_init  # used to initialize embedding parameters
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
Instantiate the model layers
        self.rwkv = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            ln_p=LayerNorm(config.n_embd, bias=config.bias),
            h=nn.ModuleList([Block(config, layer_id) for layer_id in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
Output linear layer (language-model head)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
    def forward(self, idx, targets=None, state=None, return_state=False):
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
Embedding layer
        x = self.rwkv.wte(idx)
Input layer norm
        x = self.rwkv.ln_p(x)
RWKV blocks, followed by the final layer norm
        for block_idx, block in enumerate(self.rwkv.h):
            x, state = block(x, state)
        x = self.rwkv.ln_f(x)
Logit layer and loss function (for training)
        if targets is not None:
If we are given targets, also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            if self.training:
                loss = L2Wrap.apply(loss, logits)
        else:
Inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None
Return the logits and loss (and, optionally, the recurrent state)
        if return_state:
            return logits, loss, state
        else:
            return logits, loss
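With return_state=True the model behaves like an RNN at inference time: allocate a state, feed the prompt one token at a time, then keep sampling while carrying the state forward, so each step costs O(1) in sequence length. A minimal greedy-decoding sketch (the function name and the greedy choice are placeholders; swap in temperature or top-k sampling as needed):

@torch.no_grad()
def generate_greedy(model, idx, n_new_tokens):
    # idx: (1, prompt_len) tensor of token ids
    cfg = model.config
    state = torch.zeros(cfg.n_layer, idx.size(0), 5, cfg.n_embd, device=idx.device)
    state[:, :, MAX_STATE, :] = -1e38
    logits = None
    for t in range(idx.size(1)):                                  # consume the prompt token by token
        logits, _, state = model(idx[:, t:t + 1], state=state, return_state=True)
    out = idx
    for _ in range(n_new_tokens):
        next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        out = torch.cat([out, next_id], dim=1)
        logits, _, state = model(next_id, state=state, return_state=True)
    return out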