钾肥喵的窝

我在 CODING 部署的 Hexo 博客

0%

代码连连看——TTT

传送门

arxiv

github

模型解读

我们知道Transformer的复杂度是平方的, 就很丑; 而RNN的复杂度虽然是线性的, 但是因为隐状态的大小是固定受限的, 长序列处理一直是弱项. 用可训练模型作为隐状态对RNN进行魔改就有了TTT.

连连看

Quick Start是个好东西

config 先跳过, 顺着 TTTForCausalLM 一路找下去, 先看 __init__()

1
2
3
4
5
6
7
8
def __init__(self, config):
super().__init__(config)
self.model = TTTModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

# Initialize weights and apply final processing
self.post_init()

很明显, 重要的是TTTModel, 同样是先看 __init__()

1
2
3
4
5
6
7
8
9
10
11
12
def __init__(self, config: TTTConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False

# Initialize weights and apply final processing
self.post_init()

Block 接着DFS, 可以发现有两个分支: TTTLinearTTTMLP, 它们都继承了 TTTBase, 先看最重要的forward():

实现细节先跳过, 跳到 ttt 的具体实现, 先看 TTTLinear 里的实现:

继续研究 compute_mini_batch() 是干什么的

TTTMLP 实际上就是两层 TTTLinear 叠叠乐, 这里就不做了