Skip to content

Commit f274cfb

Browse files
author
F.Liu
committed
[perf] calculate local_total_toks in build metadata
Signed-off-by: F.Liu <liufeng248@huawei.com>
1 parent f9dad48 commit f274cfb

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

vllm_ascend/attention/attention_cp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def build(
152152
dcp_rank]
153153
actual_seq_lengths_kv = torch.cumsum(
154154
local_chunked_kv_lens_rank, dim=0).tolist()
155+
local_total_toks = local_chunked_kv_lens_rank.sum()
155156
chunked_req_mask = self._get_chunked_req_mask(
156157
local_context_lens_allranks)
157158
local_chunk_starts = torch.zeros(
@@ -181,7 +182,8 @@ def build(
181182
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
182183
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
183184
batch_chunk_seq_mask=batch_chunk_seq_mask,
184-
chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices
185+
chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices,
186+
local_total_toks=local_total_toks.item()
185187
)
186188
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
187189
head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
@@ -672,7 +674,7 @@ def _compute_prefill_context(self, query: torch.Tensor,
672674

673675
local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank,
674676
self.dcp_rank]
675-
total_toks = local_chunked_kv_lens_rank.sum()
677+
total_toks = prefill_metadata.chunked_context.local_total_toks
676678

677679
key, value = self._load_kv_for_chunk(attn_metadata, kv_cache,
678680
local_chunked_kv_lens_rank, query,

vllm_ascend/attention/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class ChunkedContextMetadata:
6363
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
6464
kv_inverse_idx_for_chunk: Optional[list[int]] = None
6565
batch_chunk_seq_mask: Optional[list[bool]] = None
66+
local_total_toks: int = None
6667

6768
""" Prefill Specific Metadata for Ascend"""
6869
pcp_metadata: Optional[AscendPCPMetadata] = None

0 commit comments

Comments
 (0)