@@ -152,6 +152,7 @@ def build(
152152 dcp_rank ]
153153 actual_seq_lengths_kv = torch .cumsum (
154154 local_chunked_kv_lens_rank , dim = 0 ).tolist ()
155+ local_total_toks = local_chunked_kv_lens_rank .sum ()
155156 chunked_req_mask = self ._get_chunked_req_mask (
156157 local_context_lens_allranks )
157158 local_chunk_starts = torch .zeros (
@@ -181,7 +182,8 @@ def build(
181182 cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk ,
182183 kv_inverse_idx_for_chunk = kv_inverse_idx_for_chunk ,
183184 batch_chunk_seq_mask = batch_chunk_seq_mask ,
184- chunk_seq_mask_filtered_indices = chunk_seq_mask_filtered_indices
185+ chunk_seq_mask_filtered_indices = chunk_seq_mask_filtered_indices ,
186+ local_total_toks = local_total_toks .item ()
185187 )
186188 attn_mask_seqlens = common_long_seq_metadata .attn_mask_seqlens
187189 head_attn_nomask_seqlens = common_long_seq_metadata .head_attn_nomask_seqlens
@@ -672,7 +674,7 @@ def _compute_prefill_context(self, query: torch.Tensor,
672674
673675 local_chunked_kv_lens_rank = local_chunked_kv_lens [:, self .pcp_rank ,
674676 self .dcp_rank ]
675- total_toks = local_chunked_kv_lens_rank . sum ()
677+ total_toks = prefill_metadata . chunked_context . local_total_toks
676678
677679 key , value = self ._load_kv_for_chunk (attn_metadata , kv_cache ,
678680 local_chunked_kv_lens_rank , query ,
0 commit comments