钾肥喵的窝

我在 CODING 部署的 Hexo 博客

0%

代码连连看——RWKV

传送门

arxiv

github

模型解读

Huggingface rwkv代码解读

RWKV 模型解析

连连看

先用v4版本做连连看, v5和v6对应arxiv:2404.05892, 有时间再做连连看

老规矩, 找入口(RWKV-v4/run.py):

1
2
3
4
print(f'Loading {MODEL_NAME}...')
from src.model_run import RWKV_RNN
model = RWKV_RNN(MODEL_NAME, os.environ['RWKV_RUN_DEVICE'], model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)

后面调用的是 model.run(x), 很明显, 我们要进到 RWKV-v4/src/model_run.py 里面去

出发!

注释和mamba一样, 都是用 #// 表示

首先直奔 RWKV_RNN.run(), 看到下面这句话就知道还要继续找下去

1
w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))

看看同文件夹的有什么文件就知道该去哪里找了: RWKV-v4/src/model.py

先找最顶层的模块:

接着来看看Block是怎么实现的

先看看 RWKV_TimeMix

里面的 RUN_CUDA() 实际上是 class WKV 套娃, 而 class WKV 也是套娃, 真正的实现是用CUDA写的, 在 RWKV-v4/cuda/wkv_cuda.cuRWKV-v4/cuda/wkv_op.cpp

接下来看看 ChannelMix

是不是忘了什么? 对, 还有一步很重要的 States 的传递没有体现, 因为我们看的不是 RNN 形式, 让我们回到 RWKV_RNN, States 的传递是通过 save()load() 实现的, 其他代码和上面的基本一致, 就不赘述了

总结