1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    cache_params: Optional[TTTCache] = None,
):
    """Run the TTT layer over ``hidden_states`` and return the projected output.

    The sequence is processed in up to two parts:

    * the longest prefix whose length is a multiple of
      ``self.mini_batch_size`` goes through ``self.ttt`` with
      ``mini_batch_size=self.mini_batch_size``;
    * any leftover tail (shorter than one mini-batch) goes through a second
      ``self.ttt`` call with ``mini_batch_size`` equal to the tail length.
      That call is seeded with the ``last_mini_batch_params_dict`` returned
      by the first call, or ``None`` if there was no full mini-batch.

    The per-part outputs are concatenated along dim 1, then passed through
    ``self.post_norm``, ``self.apply_gate`` (only when ``self.use_gate``) and
    ``self.o_proj``.

    Args:
        hidden_states: Input tensor whose first two dims are read as
            ``(B, L)``, i.e. batch and sequence length. It is presumably
            ``(B, L, hidden)``; TODO confirm.
        attention_mask: Accepted for interface compatibility; not read by
            this method.
        position_ids: Token positions. They are reduced modulo
            ``self.mini_batch_size`` before being passed to
            ``self.rotary_emb``. When ``None`` and ``cache_params`` is
            ``None``, positions ``0 .. L-1`` are used for every batch row.
        cache_params: Optional cache. It is forwarded to
            ``get_qkv_projections``, ``get_ttt_inputs`` and ``ttt``.

    Returns:
        The output of ``self.o_proj``.

    Raises:
        ValueError: If ``position_ids`` is ``None`` while ``cache_params`` is
            given. This method has no way to derive the positions of a cached
            continuation on its own.
    """
    B, L = hidden_states.shape[:2]
    mini_batch_size = self.mini_batch_size
    num_mini_batch = L // mini_batch_size
    remainder_len = L % mini_batch_size
    full_len = num_mini_batch * mini_batch_size

    if position_ids is None:
        if cache_params is not None:
            raise ValueError(
                "position_ids must be provided when cache_params is used"
            )
        # Fresh (uncached) sequence: default to positions 0..L-1 per batch row.
        # NOTE(review): assumes rotary_emb accepts position ids shaped (B, L),
        # as passed by callers; confirm.
        position_ids = (
            torch.arange(L, device=hidden_states.device).unsqueeze(0).expand(B, -1)
        )

    XQ, XK, XV = self.get_qkv_projections(hidden_states, cache_params=cache_params)

    # Split channels into heads: reshape to (B, L, num_heads, head_dim), then
    # swap dims 1 and 2 -> (B, num_heads, L, head_dim).
    XQ = XQ.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
    XK = XK.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
    XV = XV.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)

    # Positions are taken modulo the mini-batch size, so for consecutive ids
    # the rotary phases repeat every mini-batch.
    cos, sin = self.rotary_emb(XV, position_ids % mini_batch_size)

    # NOTE(review): permute_qk / undo_permute_qk bracket the rotary
    # application. This is presumably a layout conversion for the rotary
    # convention; confirm.
    XQ, XK = permute_qk(XQ, XK)
    XQ, XK = apply_rotary_pos_emb(XQ, XK, cos, sin)
    XQ, XK = undo_permute_qk(XQ, XK)

    def _slice_inputs(seq_slice):
        # Select the same token span from every TTT input. After the transpose
        # above, XQ/XK/XV carry the sequence on dim 2; hidden_states carries
        # it on dim 1.
        return {
            "XQ": XQ[:, :, seq_slice],
            "XK": XK[:, :, seq_slice],
            "XV": XV[:, :, seq_slice],
            "X": hidden_states[:, seq_slice],
        }

    outputs = []
    last_mini_batch_params_dict = None

    if num_mini_batch > 0:
        output_mod, last_mini_batch_params_dict = self.ttt(
            self.get_ttt_inputs(
                _slice_inputs(slice(None, full_len)), mini_batch_size, cache_params
            ),
            mini_batch_size=mini_batch_size,
            last_mini_batch_params_dict=last_mini_batch_params_dict,
            cache_params=cache_params,
        )
        outputs.append(output_mod)

    if remainder_len > 0:
        # The tail is shorter than one mini-batch. Run it as a single
        # mini-batch of its own length, continuing from the parameters
        # returned above. The parameters this call returns are not used.
        output_remainder, _ = self.ttt(
            self.get_ttt_inputs(
                _slice_inputs(slice(-remainder_len, None)), remainder_len, cache_params
            ),
            mini_batch_size=remainder_len,
            last_mini_batch_params_dict=last_mini_batch_params_dict,
            cache_params=cache_params,
        )
        outputs.append(output_remainder)

    output_hidden_states = torch.cat(outputs, dim=1)
    output_hidden_states = self.post_norm(output_hidden_states)
    if self.use_gate:
        output_hidden_states = self.apply_gate(hidden_states, output_hidden_states)
    return self.o_proj(output_hidden_states)
|