Changes from all commits (30 commits)
6ae53b5
integrate FIA operator into mla_cp
Dec 24, 2025
08de021
make it more readable
Dec 29, 2025
048b04f
Merge branch 'main' of https://github.com/vllm-project/vllm-ascend in…
Dec 30, 2025
daafaff
adapt acl_graph in mla_cp FIA
Dec 31, 2025
cab49ba
Merge branch 'main' of https://github.com/vllm-project/vllm-ascend in…
Dec 31, 2025
452c663
adapt graph mode
Jan 5, 2026
6733ce3
support mtp
Jan 6, 2026
3650848
Merge branch 'main' of https://github.com/vllm-project/vllm-ascend in…
Jan 6, 2026
410be4d
remove redundant attributes
Jan 6, 2026
8d06f81
remove data cleaning
Jan 6, 2026
a974ca9
remove redundant variables after mla_cp forward decode uses fia
dsxsteven Jan 6, 2026
b901f38
delete redundant variables in mtp proposer
dsxsteven Jan 6, 2026
d33951f
add ut for arguments cp-kv-interleave-size
dsxsteven Jan 7, 2026
1352315
Update vllm_ascend/attention/context_parallel/mla_cp.py
845473182 Jan 7, 2026
de585e8
fix test name
dsxsteven Jan 7, 2026
fa520cd
mv ut after rebase
dsxsteven Jan 7, 2026
3fc5720
mv ut after rebase
dsxsteven Jan 7, 2026
f5f5b0a
fix pre-commit
dsxsteven Jan 7, 2026
47072e3
fix lint
Jan 7, 2026
120ac20
Merge branch 'FIA_rebase' of https://github.com/845473182/vllm-ascend…
Jan 7, 2026
7e899c6
fix lint
Jan 7, 2026
40afa15
fix lint
Jan 7, 2026
4134757
Merge branch 'main' into FIA_rebase
845473182 Jan 8, 2026
c3f5465
fix ut
Jan 8, 2026
b559ab0
Merge branch 'FIA_rebase' of https://github.com/845473182/vllm-ascend…
Jan 8, 2026
92436a2
fix lint
Jan 8, 2026
a2a6f72
[Ops] replace _update_out_and_lse with _npu_attn_out_lse_update
Jan 6, 2026
4a44ad4
Merge branch 'vllm-project:main' into main_2026_0106_remove_redunant_…
dsxsteven Jan 9, 2026
c01e0c2
Merge pull request #2 from 845473182/FIA_rebase
YzTongNiar Jan 9, 2026
445db54
Merge branch 'ops' into main_2026_0106_remove_redunant_var_after_fia
dsxsteven Jan 9, 2026
69 changes: 69 additions & 0 deletions tests/e2e/multicard/4-cards/long_sequence/test_accuracy.py
@@ -210,3 +210,72 @@ def test_accuracy_pcp_only(max_tokens: int, ) -> None:
name_0="vllm_eager_outputs",
name_1="vllm_pcp_only_outputs",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [10])
def test_cp_kv_cache_interleave_size_between_tp_and_cp(
model: str,
max_tokens: int,
) -> None:
prompts = [
"The president of the United States is", "The capital of France is"
]

common_kwargs = {
"max_model_len": 1024,
}

if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8":
cp_kwargs = {
"tensor_parallel_size": 2,
"decode_context_parallel_size": 2,
"prefill_context_parallel_size": 2,
"enable_expert_parallel": True,
"cp_kv_cache_interleave_size": 128,
"enforce_eager": True,
"quantization": "ascend",
}
tp_kwargs = {
"tensor_parallel_size": 4,
"enable_expert_parallel": True,
"enforce_eager": True,
"quantization": "ascend",
}

else:
cp_kwargs = {
"tensor_parallel_size": 1,
"decode_context_parallel_size": 1,
"prefill_context_parallel_size": 2,
"cp_kv_cache_interleave_size": 128,
"compilation_config": {
"cudagraph_mode": "FULL_DECODE_ONLY",
"cudagraph_capture_sizes": [4, 8, 24, 48, 60]
},
}
tp_kwargs = {
"tensor_parallel_size": 2,
"enforce_eager": True,
}

cp_full_kwargs = {}
cp_full_kwargs.update(common_kwargs) # type: ignore
cp_full_kwargs.update(cp_kwargs) # type: ignore

tp_full_kwargs = {}
tp_full_kwargs.update(common_kwargs) # type: ignore
tp_full_kwargs.update(tp_kwargs) # type: ignore
with VllmRunner(model, **cp_full_kwargs) as runner: # type: ignore
vllm_context_parallel_outputs = runner.generate_greedy(
prompts, max_tokens)

with VllmRunner(model, **tp_full_kwargs) as runner: # type: ignore
vllm_eager_outputs = runner.generate_greedy(prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=vllm_eager_outputs,
outputs_1_lst=vllm_context_parallel_outputs,
name_0="vllm_eager_outputs",
name_1="vllm_context_parallel_outputs",
)
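The new case compares a context-parallel configuration that sets cp_kv_cache_interleave_size=128 against a plain tensor-parallel baseline and expects identical greedy outputs. As a rough illustration of what an interleave size of this kind controls (not the actual vllm-ascend layout code; the helper name cp_rank_of_token is made up), each CP rank would hold blocks of interleave_size consecutive KV slots, round-robin across the CP group:

```python
def cp_rank_of_token(token_idx: int, interleave_size: int, cp_world_size: int) -> int:
    """Assumed mapping: blocks of `interleave_size` consecutive token slots
    are assigned to CP ranks in round-robin order."""
    return (token_idx // interleave_size) % cp_world_size


# With interleave_size=128 and a CP group of 2:
# tokens 0..127 -> rank 0, tokens 128..255 -> rank 1, tokens 256..383 -> rank 0, ...
assert cp_rank_of_token(0, 128, 2) == 0
assert cp_rank_of_token(200, 128, 2) == 1
assert cp_rank_of_token(300, 128, 2) == 0
```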
77 changes: 37 additions & 40 deletions tests/ut/attention/test_attention_cp.py
@@ -113,8 +113,6 @@ def mock_npu_fused_infer_attention_score_func(query, k_nope, value,

attn_metadata = MagicMock()
attn_metadata.decode_meta = MagicMock()
attn_metadata.decode_meta.batch_seq_mask = torch.tensor(
[1, 0], dtype=torch.bool)
output = self.impl._forward_decode_pcp_dcp(query, attn_metadata)

self.assertEqual(output.shape[0], 2)
@@ -137,8 +135,10 @@ def test_prefill_query_all_gather(self):
self.assertEqual(output.shape[2], 128)

@patch('torch.ops.npu.npu_fused_infer_attention_score')
@patch('torch_npu.npu_attention_update')
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
def test_compute_prefill_context(self, mock_npu_attention):
def test_compute_prefill_context(self, mock_npu_attention_update,
mock_npu_attention):

block_num = 100
block_size = 128
@@ -181,7 +181,9 @@ def mock_load_kv_for_chunk(attn_metadata, kv_cache,
head_size), torch.randn(
batch_size,
num_heads, 1)

mock_npu_attention_update.return_value = torch.randn(
batch_size, self.impl.num_heads,
head_size), torch.randn(batch_size, self.impl.num_heads, 1)
context_output = self.impl._compute_prefill_context(
query, kv_cache, attn_metadata)
local_context_output = torch.cat(context_output,
@@ -406,11 +408,9 @@ def test_attention_with_nomask_none(self, mock_npu_attention):
self.assertEqual(attn_lse.shape, (96, 8, 1))

@patch('torch.ops.npu.npu_fused_infer_attention_score')
@patch(
'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._update_out_and_lse'
)
@patch('torch_npu.npu_attention_update')
def test_attention_with_nomask_and_mask_chunk(
self, mock_update_out_and_lse,
self, mock_npu_attention_update,
mock_npu_fused_infer_attention_score):
# Mock input data
q = torch.randn(self.q_total_tokens, self.impl.num_heads,
@@ -432,7 +432,7 @@ def test_attention_with_nomask_and_mask_chunk(
self.q_total_tokens, self.impl.num_heads,
self.impl.head_size), torch.randn(self.q_total_tokens,
self.impl.num_heads, 1)
mock_update_out_and_lse.return_value = torch.randn(
mock_npu_attention_update.return_value = torch.randn(
self.q_total_tokens, self.impl.num_heads,
self.impl.head_size), torch.randn(self.q_total_tokens,
self.impl.num_heads, 1)
@@ -481,8 +481,12 @@ def test_attention_with_nomask_and_mask_nochunk(
self.q_total_tokens, self.impl.num_heads,
self.impl.head_size), torch.randn(self.q_total_tokens,
self.impl.num_heads, 1)
mock_npu_attn_out_lse_update.return_value = torch.randn(
self.q_total_tokens, self.impl.num_heads, self.impl.head_size)
mock_npu_attn_out_lse_update.return_value = (torch.randn(
self.q_total_tokens, self.impl.num_heads, self.impl.head_size),
torch.randn(
self.q_total_tokens,
self.impl.num_heads,
1))

# Call the method under test
output, attn_lse = self.impl._attention_with_nomask_and_mask(
@@ -500,7 +504,6 @@ def test_attention_with_nomask_and_mask_nochunk(
mock_npu_attn_out_lse_update.assert_called_once()
self.assertEqual(mock_npu_fused_infer_attention_score.call_count, 2)
self.assertIsNotNone(output)
self.assertEqual(attn_lse, None)

@patch(
'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
@@ -550,43 +553,26 @@ def test_npu_attn_out_lse_update(self, mock_npu_attention_update):
attn_out_nomask = torch.randn(8, 128, 128)

# Mock output
mock_npu_attention_update.return_value = (torch.randn(8 * 128,
128), None)
mock_npu_attention_update.return_value = (torch.randn(8 * 128, 128),
torch.randn(8 * 128, 1))

# Call the method under test
output = self.impl._npu_attn_out_lse_update(attn_lse_mask,
attn_lse_nomask,
attn_out_mask,
attn_out_nomask)
output, _ = self.impl._npu_attn_out_lse_update(attn_lse_mask,
attn_lse_nomask,
attn_out_mask,
attn_out_nomask)

# Assert the method call
self.assertIsInstance(output, torch.Tensor)
self.assertEqual(output.shape, (8, 128, 128))

mock_npu_attention_update.assert_called_once()

def test_update_out_and_lse(self):
# Mock input data
out_list = torch.randn(3, 2, 4,
8) # [N, batch_size, num_heads, head_size]
lse_list = torch.randn(3, 2, 4, 1) # [N, batch_size, num_heads, 1]

# Call the method under test
out_final, lse_final = self.impl._update_out_and_lse(
out_list, lse_list)

# Assert the method call
self.assertEqual(out_final.shape,
(2, 4, 8)) # [batch_size, num_heads, head_size]
self.assertEqual(lse_final.shape,
(2, 4, 1)) # [batch_size, num_heads, 1]

self.assertIsInstance(out_final, torch.Tensor)
self.assertIsInstance(lse_final, torch.Tensor)

@patch('torch_npu.npu_attention_update')
@patch_distributed_groups(dcp_size=2, pcp_size=3)
def test_update_chunk_attn_out_lse_dcp2_pcp3(self, mock_all_to_all_single,
mock_dcp, mock_pcp):
mock_dcp, mock_pcp,
mock_npu_attention_update):
# Mock input data
prefix_chunk_output = torch.randn(2, 4, 8)
prefix_chunk_lse = torch.randn(2, 4, 1)
@@ -601,6 +587,8 @@ def test_update_chunk_attn_out_lse_dcp2_pcp3(self, mock_all_to_all_single,
chunk_data)
global_context_output = global_context_output.permute([2, 0, 1
]).contiguous()
mock_npu_attention_update.return_value = (torch.randn(2, 2, 8),
torch.randn(2, 2, 1))
output, lse = self.impl._update_global_context_output(
global_context_output)

@@ -613,9 +601,11 @@ def test_update_chunk_attn_out_lse_dcp2_pcp3(self, mock_all_to_all_single,
mock_all_to_all_single.assert_called_once()
mock_pcp.all_gather.assert_called_once()

@patch('torch_npu.npu_attention_update')
@patch_distributed_groups(dcp_size=2)
def test_update_chunk_attn_out_lse_dcp2_pcp1(self, mock_all_to_all_single,
mock_dcp, mock_pcp):
mock_dcp, mock_pcp,
mock_npu_attention_update):
# Mock input data
prefix_chunk_output = torch.randn(2, 4, 8)
prefix_chunk_lse = torch.randn(2, 4, 1)
@@ -631,6 +621,8 @@ def test_update_chunk_attn_out_lse_dcp2_pcp1(self, mock_all_to_all_single,
chunk_data)
global_context_output = global_context_output.permute([2, 0, 1
]).contiguous()
mock_npu_attention_update.return_value = (torch.randn(2, 2, 8),
torch.randn(2, 2, 1))
output, lse = self.impl._update_global_context_output(
global_context_output)

Expand All @@ -643,9 +635,11 @@ def test_update_chunk_attn_out_lse_dcp2_pcp1(self, mock_all_to_all_single,
mock_all_to_all_single.assert_called_once()
mock_pcp.all_gather.assert_not_called()

@patch('torch_npu.npu_attention_update')
@patch_distributed_groups(pcp_size=2)
def test_update_chunk_attn_out_lse_dcp1_pcp2(self, mock_all_to_all_single,
mock_dcp, mock_pcp):
mock_dcp, mock_pcp,
mock_npu_attention_update):
# Mock input data
prefix_chunk_output = torch.randn(2, 4, 8)
prefix_chunk_lse = torch.randn(2, 4, 1)
@@ -661,6 +655,9 @@ def test_update_chunk_attn_out_lse_dcp1_pcp2(self, mock_all_to_all_single,
chunk_data)
global_context_output = global_context_output.permute([2, 0, 1
]).contiguous()
mock_npu_attention_update.return_value = torch.randn(2, 4,
8), torch.randn(
2, 4, 1)
output, lse = self.impl._update_global_context_output(
global_context_output)

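Throughout this file the removed _update_out_and_lse helper (and its dedicated test) gives way to mocks of the fused torch_npu.npu_attention_update op, which now returns both the merged output and the merged LSE. For reference, the sketch below is the standard log-sum-exp merge of two partial attention results, written from the textbook formula rather than from the fused NPU kernel, whose exact signature and tensor layout may differ:

```python
import torch


def merge_attn_out_lse(out_a: torch.Tensor, lse_a: torch.Tensor,
                       out_b: torch.Tensor, lse_b: torch.Tensor):
    """Combine two partial attention outputs [tokens, heads, head_size]
    using their log-sum-exp terms [tokens, heads, 1]."""
    lse_max = torch.maximum(lse_a, lse_b)
    # Rescale both partial softmax denominators to a common exponent for stability.
    w_a = torch.exp(lse_a - lse_max)
    w_b = torch.exp(lse_b - lse_max)
    merged_out = (out_a * w_a + out_b * w_b) / (w_a + w_b)
    merged_lse = lse_max + torch.log(w_a + w_b)
    return merged_out, merged_lse
```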
23 changes: 7 additions & 16 deletions tests/ut/attention/test_mla_cp.py
@@ -439,22 +439,19 @@ def test_process_attn_out_lse(self):
decode_metadata = MagicMock()
decode_metadata.actual_seq_lengths_q = MagicMock()
decode_metadata.seq_lens_list = MagicMock()
decode_metadata.batch_seq_mask = torch.tensor([True, False],
dtype=torch.bool)

result = _process_attn_out_lse(attn_output, softmax_lse,
decode_metadata.batch_seq_mask)
result = _process_attn_out_lse(attn_output, softmax_lse)

self.assertEqual(result.shape[0], B * self.impl.pcp_size)
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)

@patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
@patch("torch_npu.atb.npu_multi_head_latent_attention")
@patch("torch_npu.npu_fused_infer_attention_score")
@patch('torch_npu.npu_attention_update')
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
mock_npu_multi_head_latent_attention,
mock_npu_fused_infer_attention_score,
mock_get_forward_context):
self.impl.dcp_size = 2
self.impl.pcp_size = 2
@@ -470,22 +467,20 @@ def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,

q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
k_nope = torch.randn(NB, BS, 1, self.impl.kv_lora_rank)
k_pe = torch.randn(NB, BS, 1, self.impl.qk_rope_head_dim)
k_nope = torch.randn(NB, 1, BS, self.impl.kv_lora_rank)
k_pe = torch.randn(NB, 1, BS, self.impl.qk_rope_head_dim)

attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.SpecDecoding
attn_metadata.decode = MagicMock()
attn_metadata.decode.actual_seq_lengths_q = MagicMock()
attn_metadata.decode.seq_lens_list = MagicMock()
attn_metadata.decode.batch_seq_mask = torch.tensor([False, False],
dtype=torch.bool)

self.impl.enable_kv_nz = True

mock_npu_attention_update.return_value = (torch.randn(
B, self.impl.num_heads, self.impl.kv_lora_rank), None)
mock_npu_multi_head_latent_attention.return_value = [
mock_npu_fused_infer_attention_score.return_value = [
torch.randn(B, N, self.impl.kv_lora_rank),
torch.randn(B, N, 1)
]
@@ -886,12 +881,8 @@ def test_process_attn_out_lse_with_dcp_pcp(self, mock_all_to_all, mock_dcp,
# Inputs
attn_output = torch.randn(B, H, D)
softmax_lse = torch.randn(B, H, 1)
batch_seq_mask = torch.tensor([False, True, False, False]) # [B]
decode_meta = MagicMock()
decode_meta.batch_seq_mask = batch_seq_mask

result = _process_attn_out_lse(attn_output, softmax_lse,
batch_seq_mask)
result = _process_attn_out_lse(attn_output, softmax_lse)
# [PCP * S, DCP * H, D + 1]
self.assertIsInstance(result, torch.Tensor)
assert result.shape == (B * self.impl.pcp_size, H, D + 1)
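The updated _process_attn_out_lse tests drop the batch_seq_mask argument and now only assert shapes: a [B, H, D] output and a [B, H, 1] LSE come back as a [B * pcp_size, H, D + 1] tensor. A guess at the packing step being exercised (the cross-rank PCP/DCP communication is omitted and the helper name pack_out_lse is invented):

```python
import torch


def pack_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor) -> torch.Tensor:
    """Concatenate attention output [B, H, D] and its LSE [B, H, 1] into one
    [B, H, D + 1] tensor so both can be moved together; the real helper is
    assumed to additionally gather the result across context-parallel ranks."""
    return torch.cat([attn_output, softmax_lse], dim=-1)


packed = pack_out_lse(torch.randn(4, 16, 512), torch.randn(4, 16, 1))
assert packed.shape == (4, 16, 513)
```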
5 changes: 1 addition & 4 deletions tests/ut/attention/test_mla_v1.py
@@ -121,16 +121,14 @@ def test_ascend_mla_decode_metadata_default(self):
seq_lens_list = [2, 3]
attn_mask = None
cp_seq_len = torch.tensor([2, 3])
batch_seq_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])

metadata = AscendMLADecodeMetadata(input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
max_seq_lens=max_seq_lens,
seq_lens_list=seq_lens_list,
attn_mask=attn_mask,
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
cp_seq_len=cp_seq_len)

self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.block_table, block_table)
@@ -139,7 +137,6 @@ def test_ascend_mla_decode_metadata_default(self):
self.assertEqual(metadata.seq_lens_list, seq_lens_list)
self.assertIsNone(attn_mask)
self.assertIs(metadata.cp_seq_len, cp_seq_len)
self.assertIs(metadata.batch_seq_mask, batch_seq_mask)


class TestAscendMLAMetadata(TestBase):
12 changes: 8 additions & 4 deletions tests/ut/compilation/test_acl_graph.py
@@ -754,7 +754,7 @@ def setUp(self):

@patch('torch.npu.graph_task_update_end', )
@patch('torch.npu.graph_task_update_begin', MagicMock())
@patch('torch_npu.atb.npu_multi_head_latent_attention', MagicMock())
@patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
block_table = torch.zeros(2, 5, dtype=torch.long)
@@ -793,16 +793,20 @@ def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
qk_rope_head_dim = 32
qk_nope_head_dim = 64
query = torch.randn(4, num_heads, qk_head_dim)
q_pe = query[..., qk_nope_head_dim:]

q_nope = query[..., :qk_nope_head_dim]
q_pe = query[..., qk_rope_head_dim:]
k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
k_pe = torch.randn(4, num_heads, qk_rope_head_dim)
input_layout = "BNSD"
actual_seq_lengths_kv = [1, 1]
out = torch.randn(2, 16, 128)
lse = torch.randn(2, 16, 8)
self.graph_params.attn_params[4] = []
self.graph_params.attn_params[4].append(
(q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads,
scale, num_kv_heads, out, lse))
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
None, 0, scale, block_table, 128, None, actual_seq_lengths_kv,
out, lse))

with patch("torch_npu._C._npu_setStream", return_value=None):
update_mla_attn_dcp_pcp_params(self.update_stream, forward_context,