Receptance Weighted Key Value (RWKV)

This is a tutorial/implementation in PyTorch of RWKV from the paper RWKV: Reinventing RNNs for the Transformer Era.

Full definition of an RWKV language model, all of it in this single file. References: 1) the official RWKV PyTorch implementation released by Bo Peng, 2) the huggingface/transformers PyTorch implementation.
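The code below reads several fields from a config object. As a minimal sketch (the field names are inferred from how they are used in this file and the defaults are only illustrative, not an official API), such a config could look like:

from dataclasses import dataclass
from typing import Optional

@dataclass
class RWKVConfig:
    vocab_size: int = 50304                  # size of the token vocabulary
    block_size: int = 1024                   # maximum sequence length the model accepts
    n_layer: int = 12                        # number of RWKV blocks
    n_embd: int = 768                        # embedding / channel dimension
    intermediate_size: Optional[int] = None  # channel-mixing hidden size; defaults to 4 * n_embd below
    bias: bool = True                        # whether LayerNorm carries a bias term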

import torch
import torch.nn as nn
from torch.nn import functional as F


PREV_X_TIME = 0
NUM_STATE = 1
DEN_STATE = 2
MAX_STATE = 3
PREV_X_CHANNEL = 4
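These constants index the third dimension of a single recurrent-state tensor that is threaded through the model during inference. Judging from how the state is indexed later in this file (state[layer_id, :, INDEX, :]), a compatible state can be allocated as follows; this helper is an illustrative assumption, not part of the original file:

def init_state(n_layer, batch_size, n_embd, device="cpu"):
    # One slot per layer and batch element: previous input for time mixing,
    # WKV numerator, WKV denominator, running max (for numerical stability),
    # and previous input for channel mixing.
    state = torch.zeros(n_layer, batch_size, 5, n_embd, device=device)
    # The running max must start very low, matching the -1e38 used in TimeMixing.
    state[:, :, MAX_STATE, :] = -1e38
    return state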

Layer normalization with bias

class LayerNorm(nn.Module):
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

L2 loss wrapper, following the reference implementation

class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]

Add a small extra gradient on the largest logit at each position, to encourage the logits to stay close to 0.

        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return grad_output, gy
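A quick way to see the effect (a standalone sketch): if the wrapped loss does not depend on y, the only gradient y receives is this extra term, i.e. factor times the maximum logit, scattered at each position's argmax.

y = torch.randn(2, 3, 5, requires_grad=True)       # (batch, time, vocab)
dummy_loss = torch.zeros((), requires_grad=True)   # scalar loss independent of y
L2Wrap.apply(dummy_loss, y).backward()
# y.grad is zero everywhere except at each position's argmax,
# where it equals 1e-4 / (2 * 3) * y.max(-1).values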

Channel Mixing
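Written out, this block computes the following for each time step t, where the mu's are the learnable token-shift mixing vectors, the W's are the learnable projection matrices below, and sigma is the sigmoid:

$$r_t = W_r \left(\mu_r \odot x_t + (1-\mu_r) \odot x_{t-1}\right)$$
$$k_t = W_k \left(\mu_k \odot x_t + (1-\mu_k) \odot x_{t-1}\right)$$
$$\mathrm{out}_t = \sigma(r_t) \odot \left(W_v \,\max(k_t, 0)^2\right)$$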

class ChannelMixing(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

Token shifting: pad one zero step at the start of the time dimension and drop the last step, so position t sees the input from position t-1 (see the sketch after this class).

        self.layer_id = layer_id

        n_embd = config.n_embd
        intermediate_size = (
            config.intermediate_size if config.intermediate_size is not None else 4 * n_embd
        )

Learnable matrices

        self.key_proj = nn.Linear(n_embd, intermediate_size, bias=False)
        self.value_proj = nn.Linear(intermediate_size, n_embd, bias=False)
        self.receptance_proj = nn.Linear(n_embd, n_embd, bias=False)

Learnable vectors

        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

x has shape (Batch, Time, Channel)

    def forward(self, x, state=None):
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_CHANNEL], :]
            state[self.layer_id, :, [PREV_X_CHANNEL], :] = x
        else:
            prev_x = self.time_shift(x)

        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)

        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)

        value = self.value_proj(torch.square(torch.relu(key)))

        out = F.sigmoid(receptance) * value
        return out, state
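As a small sketch of the token shifting used above: nn.ZeroPad2d((0, 0, 1, -1)) applied to a (Batch, Time, Channel) tensor pads one zero step at the start of the time dimension and drops the last step, so every position sees its predecessor.

time_shift = nn.ZeroPad2d((0, 0, 1, -1))
x = torch.arange(1., 7.).view(1, 3, 2)  # (Batch=1, Time=3, Channel=2)
print(x[0])
# tensor([[1., 2.],
#         [3., 4.],
#         [5., 6.]])
print(time_shift(x)[0])
# tensor([[0., 0.],
#         [1., 2.],
#         [3., 4.]])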

Time Mixing
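Written out, time mixing forms token-shifted receptance, key, and value projections and gates the WKV term (defined in the WKV calculation below) with the receptance:

$$r_t = W_r(\mu_r \odot x_t + (1-\mu_r)\odot x_{t-1}),\qquad k_t = W_k(\mu_k \odot x_t + (1-\mu_k)\odot x_{t-1}),\qquad v_t = W_v(\mu_v \odot x_t + (1-\mu_v)\odot x_{t-1})$$
$$\mathrm{out}_t = W_o\left(\sigma(r_t) \odot wkv_t\right)$$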

class TimeMixing(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.layer_id = layer_id

        n_embd = config.n_embd
        attn_sz = n_embd

learnable matrices

        self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)

learnable vectors

        self.time_decay = nn.Parameter(torch.empty(attn_sz))
        self.time_first = nn.Parameter(torch.empty(attn_sz))
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

x has shape (Batch, Time, Channel)

    def forward(self, x, state=None):
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_TIME], :]
            state[self.layer_id, :, [PREV_X_TIME], :] = x
        else:
            prev_x = self.time_shift(x)

        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)

        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)

        value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
        value = self.value_proj(value)

WKV calculation
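The loop below evaluates the WKV recurrence in a numerically stable way: max_state tracks the largest exponent seen so far (the standard log-sum-exp trick). Writing w for the positive per-channel decay (the code stores its negation as time_decay = -exp(self.time_decay)) and u for time_first, it computes

$$wkv_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)\,w + k_i}\, v_i \;+\; e^{u + k_t}\, v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)\,w + k_i} \;+\; e^{u + k_t}}$$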

        _, seq_length, _ = key.size()
        output = torch.zeros_like(key)

        if state is None:
            num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
        else:
            num_state = state[self.layer_id, :, NUM_STATE, :]
            den_state = state[self.layer_id, :, DEN_STATE, :]
            max_state = state[self.layer_id, :, MAX_STATE, :]

        time_decay = -torch.exp(self.time_decay)

        for current_index in range(seq_length):
            current_key = key[:, current_index].float()
            current_value = value[:, current_index]

            max_for_output = torch.maximum(max_state, current_key + self.time_first)
            e1 = torch.exp(max_state - max_for_output)
            e2 = torch.exp(current_key + self.time_first - max_for_output)
            numerator = e1 * num_state + e2 * current_value
            denominator = e1 * den_state + e2
            output[:, current_index] = (numerator / denominator).to(output.dtype)

Update state for next iteration

            max_for_state = torch.maximum(max_state + time_decay, current_key)
            e1 = torch.exp(max_state + time_decay - max_for_state)
            e2 = torch.exp(current_key - max_for_state)
            num_state = e1 * num_state + e2 * current_value
            den_state = e1 * den_state + e2
            max_state = max_for_state

Write the final numerator/denominator/max state back for the next call (only when a recurrent state was passed in), and use the accumulated output as the WKV term.

        if state is not None:
            state[self.layer_id, :, NUM_STATE, :] = num_state
            state[self.layer_id, :, DEN_STATE, :] = den_state
            state[self.layer_id, :, MAX_STATE, :] = max_state
        wkv = output

        rwkv = F.sigmoid(receptance) * wkv
        rwkv = self.output_proj(rwkv)

        return rwkv, state

RWKV block element

class Block(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = TimeMixing(config, layer_id)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.ffn = ChannelMixing(config, layer_id)

    def forward(self, x, state=None):

time mixing

        residual = x
        x, state = self.attn(self.ln_1(x), state=state)
        x = x + residual

channel mixing

        residual = x
        x, state = self.ffn(self.ln_2(x), state=state)
        x = x + residual
        return x, state

RWKV

class RWKV(nn.Module):
    def __init__(self, config, lr_init=0.0008):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.lr_init = lr_init  # used to initialize embedding parameters
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd

Initialize model layers

        self.rwkv = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            ln_p=LayerNorm(config.n_embd, bias=config.bias),
            h=nn.ModuleList([Block(config, layer_id) for layer_id in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))

Output linear layer

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, idx, targets=None, state=None, return_state=False):
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

Embedding Layer

        x = self.rwkv.wte(idx)

Layer Norm

        x = self.rwkv.ln_p(x)

RWKV Blocks

        for block_idx, block in enumerate(self.rwkv.h):
            x, state = block(x, state)
        x = self.rwkv.ln_f(x)

Logit layer and loss function (for training)

        if targets is not None:

If we are given some desired targets, also calculate the loss.

            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            if self.training:
                loss = L2Wrap.apply(loss, logits)
        else:

inference-time mini-optimization: only forward the lm_head on the very last position

            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None

Return logits and loss

        if return_state:
            return logits, loss, state
        else:
            return logits, loss
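Finally, a hedged end-to-end usage sketch, reusing the illustrative RWKVConfig and init_state helpers from earlier in this page (they are not part of the original file). Note that the parameters created with torch.empty above are expected to be initialized elsewhere before real training; here the values only need to flow through the graph.

config = RWKVConfig(vocab_size=100, block_size=64, n_layer=2, n_embd=32)
model = RWKV(config)

# Training-style call: full sequences in, logits and loss out.
idx = torch.randint(0, config.vocab_size, (4, 16))       # (batch, time)
targets = torch.randint(0, config.vocab_size, (4, 16))
logits, loss = model(idx, targets=targets)

# Inference: feed one token at a time and thread the recurrent state through.
model.eval()
state = init_state(config.n_layer, batch_size=1, n_embd=config.n_embd)
token = torch.randint(0, config.vocab_size, (1, 1))
with torch.no_grad():
    for _ in range(8):
        logits, _, state = model(token, state=state, return_state=True)
        token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)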