传送门
模型解读
Mamba系列日积月累(一):状态空间模型SSM的离散化过程推导
连连看
首先从用法开始入手:
1 | import torch |
IDE会告诉我们:
入口在mamba_ssm/modules/mamba_simple.py
中.
接下来就来会会它
出发!
先来看看Mamba.forward()
,
很容易就能看到Mamba.step()
是关键点
1 | def forward(self, hidden_states, inference_params=None): |
Mamba.step()
完整实现了Figure3, 为了和源码中的注释区分,
混合了C风格的注释(#//
, 下同)
点我查看完整代码
1 | def step(self, hidden_states, conv_state, ssm_state): |
默认的selective_state_update
实现,
位于mamba_ssm/ops/triton/selective_state_update.py
中,
因为使用了一些优化手段, 阅读起来比较复杂,
但是作者给了方便阅读的对应的_ref
代码, 这里就不赘述了.
再回到Mamba.forward()
,
可以知道位于mamba_ssm/ops/selective_scan_interface.py
中的SelectiveScanFn
和MambaInnerFn
是实际运行时采用的代码,
作者同样提供了对应的_ref
版本, 基本逻辑一致, 直接跳过了
总结
Mamba.step()
中已经有了S6的计算, 为何又在selective_scan_fn
中重写了操作?- 有些代码细节和论文不一致, 以后读论文和代码要注意
A
矩阵为什么要取log再存储, 还原的时候又要加负号?