1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
class RWKV_TimeMix(torch.jit.ScriptModule):
    """Time-mixing sub-block of an RWKV layer.

    Each input position is blended with the previous position (a one-step
    shift along the time axis) using learned per-channel mixing weights.
    The blends are projected to key / value / receptance. ``sigmoid(r)``
    gates the result of ``RUN_CUDA``, and a final linear projection maps it
    back to ``n_embd`` channels.

    Args:
        config: object exposing ``ctx_len``, ``n_embd`` and ``n_layer``.
        layer_id: index of this layer. The layer-dependent initialisation
            presumably expects ``0 <= layer_id < n_layer`` (confirm with callers).
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd

        attn_sz = config.n_embd
        # dtype that torch.ones()/torch.tensor(<float>) produce, so the
        # vectorised init below yields parameters of the same dtype as before.
        dtype = torch.get_default_dtype()

        with torch.no_grad():
            # 0 for the first layer, 1 for the last one. The denominator is
            # clamped so a single-layer model (n_layer == 1) does not divide by zero.
            ratio_0_to_1 = layer_id / max(config.n_layer - 1, 1)
            # 1 for the first layer, falling to 1/n_layer for the last one.
            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer)

            # time_decay: rises from -5 (first channel) to +3 (last channel)
            # along a power curve whose exponent grows with depth. It is evaluated
            # in float64, like the scalar Python arithmetic it replaces, and
            # cast once. The clamp avoids 0/0 when attn_sz == 1.
            channel_pos = torch.arange(attn_sz, dtype=torch.float64) / max(attn_sz - 1, 1)
            decay_speed = -5 + 8 * channel_pos ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed.to(dtype))

            # time_first: log(0.3) plus a repeating 0, +0.5, -0.5 per-channel offset.
            zigzag = ((torch.arange(attn_sz) + 1) % 3 - 1) * 0.5
            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)

            # time_mix_*: per-channel weights for blending the current position
            # with the previous one. They are built from the ramp i / n_embd
            # (rounded to `dtype` before pow) and shaped (1, 1, n_embd) so they
            # broadcast over the first two dims of the input.
            ramp = torch.arange(config.n_embd, dtype=torch.float64) / config.n_embd
            ramp = ramp.to(dtype).view(1, 1, config.n_embd)
            self.time_mix_k = nn.Parameter(torch.pow(ramp, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(ramp, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(ramp, 0.5 * ratio_1_to_almost0))

        # ZeroPad2d pads the last two dims as (left, right, top, bottom). Here
        # dim -2 gets one zero row prepended and its last row cropped, so
        # position t receives position t-1 and position 0 receives zeros.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)

        # NOTE(review): scale_init is not read in this class. Presumably a custom
        # weight-initialisation pass elsewhere consumes it. Confirm there.
        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    @torch.jit.script_method
    def jit_func(self, x):
        """Return ``(sigmoid(r), k, v)`` computed from time-shifted blends of ``x``.

        Each projection gets its own mix of ``x`` and its one-step time shift.
        The mix is ``x * m + shift(x) * (1 - m)``, where ``m`` is
        ``time_mix_k``, ``time_mix_v`` or ``time_mix_r`` respectively.
        """
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)

        return sr, k, v

    def forward(self, x):
        """Apply time mixing to the rank-3 input ``x`` (unpacked as ``B, T, C``).

        ``RUN_CUDA`` is defined elsewhere in this module. It is presumably the
        custom recurrent (WKV) kernel; confirm against its definition. Its
        result is gated by ``sigmoid(r)`` and passed through ``self.output``.
        """
        B, T, C = x.size()

        sr, k, v = self.jit_func(x)

        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
        rwkv = self.output(rwkv)
        return rwkv
|