钾肥喵的窝

我在 CODING 部署的 Hexo 博客

0%

代码连连看——mamba

传送门

arxiv

github

模型解读

【汇报】 Mamba模型及其公式推导

Mamba系列日积月累(一):状态空间模型SSM的离散化过程推导

连连看

首先从用法开始入手:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
# This module uses roughly 3 * expand * d_model^2 parameters
d_model=dim, # Model dimension d_model
d_state=16, # SSM state expansion factor
d_conv=4, # Local convolution width
expand=2, # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape

IDE会告诉我们: 入口在mamba_ssm/modules/mamba_simple.py中. 接下来就来会会它

出发!

先来看看Mamba.forward(), 很容易就能看到Mamba.step()是关键点

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def forward(self, hidden_states, inference_params=None):
"""
hidden_states: (B, L, D)
Returns: same shape as hidden_states
"""
batch, seqlen, dim = hidden_states.shape

conv_state, ssm_state = None, None
if inference_params is not None:
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
if inference_params.seqlen_offset > 0:
# The states are updated inplace
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
return out
...

Mamba.step()完整实现了Figure3, 为了和源码中的注释区分, 混合了C风格的注释(#//, 下同)

默认的selective_state_update实现, 位于mamba_ssm/ops/triton/selective_state_update.py中, 因为使用了一些优化手段, 阅读起来比较复杂, 但是作者给了方便阅读的对应的_ref代码, 这里就不赘述了.

再回到Mamba.forward(), 可以知道位于mamba_ssm/ops/selective_scan_interface.py中的SelectiveScanFnMambaInnerFn是实际运行时采用的代码, 作者同样提供了对应的_ref版本, 基本逻辑一致, 直接跳过了

总结

  1. Mamba.step()中已经有了S6的计算, 为何又在selective_scan_fn中重写了操作?
  2. 有些代码细节和论文不一致, 以后读论文和代码要注意
  3. A矩阵为什么要取log再存储, 还原的时候又要加负号?