
Conversation

Contributor

@ForBetterCodeNine ForBetterCodeNine commented Jan 4, 2026

What this PR does / why we need it?

This PR builds upon PR #5011 and further enhances the npu_graph_ex_passes module. On top of that prior work, it adds graph optimization support for the add_rms_norm_quant fused operator in scenarios where a bias term is present, ensuring the fusion pattern is correctly registered and matched in the computation graph.

For validation, we switched to the Qwen3-235B-A22B-W8A8 model for SPPatternWithBias and the Qwen3-32B model for SPPattern. Benchmark results show that, compared to the unfused baseline, enabling this fusion pass significantly improves inference throughput for W8A8 quantized models.
For more details, refer to the RFC: #4715
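
At a high level, the pass matches the unfused add-RMSNorm → bias-add → all-gather → quantize subgraph and rewrites it into a single fused kernel call. The following is only a condensed sketch distilled from the pattern/replacement functions quoted in the review below; it assumes a torch_npu + vLLM Ascend environment where these ops are registered, and the tensor names are placeholders.

```python
import torch

def unfused(x, residual, weight, bias, scale, offset, epsilon):
    # What the WithBias pattern matches in the FX graph.
    out = torch.ops.npu.npu_add_rms_norm(x, residual, weight, epsilon)
    y, new_residual = out[0] + bias, out[2]
    y = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(y, True)
    q = torch.ops.npu.npu_quantize(y, scale, offset, torch.qint8, -1, False)
    return q, new_residual

def fused(x, residual, weight, bias, scale, offset, epsilon):
    # What the pass rewrites the match into: one fused kernel, then the gather.
    out = torch.ops.npu.npu_add_rms_norm_quant(
        x, residual, weight,
        1. / scale,  # the fused kernel expects the inverse of the npu_quantize scale
        offset, epsilon=epsilon, beta=bias)
    q = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out[0], True)
    return q, out[2]
```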

Does this PR introduce any user-facing change?

no

How was this patch tested?

llm = LLM(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
        enforce_eager=False,
        enable_expert_parallel=enable_expert_parallel,
        trust_remote_code=trust_remote_code,
        gpu_memory_utilization=0.98,
        max_num_batched_tokens=512,
        # load_format="dummy",
        max_model_len=2048,
        max_num_seqs=16,
        quantization="ascend",
        additional_config={
            "refresh": True,
            "enable_npugraph_ex": True
        },
        compilation_config={
            "cudagraph_capture_sizes": [8, 16],
            "cudagraph_mode": "FULL_DECODE_ONLY",
        },
    )
    if profile_dir:
        llm.start_profile()
    outputs = llm.generate(prompts, sampling_params)
    if profile_dir:
        llm.stop_profile()
    for i, output in enumerate(outputs):
        if i >= 5:
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
            f"Generated text: {generated_text!r}"
        )

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces two new fusion patterns, AddRMSNormSPPattern and AddRMSNormSPPatternWithBias. However, the implementation contains several critical bugs that prevent it from working correctly. Both new functions have Python syntax errors due to incorrect indentation, flawed logic for stream checking that would prevent fusions, and most importantly, they are missing the call to register the fusion patterns, making them completely ineffective. I have provided two comprehensive review comments with code suggestions that fix all these issues for both functions.

Comment on lines 216 to 219
def replacement_add_rms_norm_quant_sp_pattern(epsilon):
    if 'torch_npu' not in sys.modules:
        logger.info(
            'The AddRMSNormQuantSPPattern fusion will only be enabled in a torch npu env.'
            'When there is no torch_npu in the env, skip fusion.')
        return

    def _extra_stream_scope_check(match: Match) -> bool:
        """
        Checks if all nodes in the same stream.
        """
        non_default_streams = set()
        has_default = False

        for node in match.nodes:
            if node.op == "call_function":
                current_stream = node.meta.get("stream_label")
                if current_stream is None:
                    has_default = True
                else:
                    logger.debug(
                        f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPattern. "
                        f"Multiple streams found: {non_default_streams}. "
                        f"Fusion is not supported for cross-stream operations."
                    )
                    return False

        if has_default and len(non_default_streams) > 0:
            logger.debug(
                f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPattern. "
                f"Multiple streams found: {non_default_streams}. "
                f"Fusion is not supported for cross-stream operatiosn.")
            return False

        return True

    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                offset: torch.Tensor):
        """
        Pattern for AddRMSNormQuantSPPattern fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
                                                rms_norm_weight, epsilon)
        out0 = output[0]
        out1 = output[2]
        out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
                                                      torch.qint8, -1, False)
        return quantized_output, out1

    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                    offset: torch.Tensor):
        """
        Replacement for the AddRMSNormQuantSPPattern fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm_quant(
            rms_norm_input,
            residual,
            rms_norm_weight,
            # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
            1. / scale,
            offset,
            epsilon=epsilon)
        quantized_output = output[0]
        out1 = output[2]
        quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            quantized_output, True)
        return quantized_output, out1

    def get_inputs():
        """
        Generate example inputs for the AddRMSNormQuantSPPattern fusion pattern.
        """
        rms_norm_input = torch.randn(2, 4, device="npu")
        residual = torch.randn(2, 4, device="npu")
        rms_norm_weight = torch.randn(4, device="npu")
        scale = torch.ones(4, device="npu")
        offset = torch.zeros(4, device="npu")
        return [rms_norm_input, residual, rms_norm_weight, scale, offset]

critical

The implementation of replacement_add_rms_norm_quant_sp_pattern has several critical issues that need to be addressed:

  1. Syntax Error: There's an indentation error in the if 'torch_npu' not in sys.modules: block, which will cause a SyntaxError.
  2. Incorrect Stream Check: The _extra_stream_scope_check logic is flawed. It incorrectly flags valid single non-default stream scenarios as cross-stream operations, preventing fusion.
  3. Missing Registration: The function defines a fusion pattern but never registers it with torchair.register_replacement. This makes the entire function ineffective.

The suggested code below fixes these issues.

def replacement_add_rms_norm_quant_sp_pattern(epsilon):
    if 'torch_npu' not in sys.modules:
        logger.info(
            'The AddRMSNormQuantSPPattern fusion will only be enabled in a torch npu env.'
            'When there is no torch_npu in the env, skip fusion.')
        return

    def _extra_stream_scope_check(match: Match) -> bool:
        """
        Checks if all nodes in the same stream.
        """
        non_default_streams = set()
        has_default = False

        for node in match.nodes:
            if node.op == "call_function":
                current_stream = node.meta.get("stream_label")
                if current_stream is None:
                    has_default = True
                else:
                    non_default_streams.add(current_stream)
                    if len(non_default_streams) > 1:
                        logger.debug(
                            f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPattern. "
                            f"Multiple streams found: {non_default_streams}. "
                            f"Fusion is not supported for cross-stream operations."
                        )
                        return False

        if has_default and len(non_default_streams) > 0:
            logger.debug(
                f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPattern. "
                f"Multiple streams found: {non_default_streams}. "
                f"Fusion is not supported for cross-stream operations.")
            return False
        
        return True

    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                offset: torch.Tensor):
        """
        Pattern for AddRMSNormQuantSPPattern fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
                                                rms_norm_weight, epsilon)
        out0 = output[0]
        out1 = output[2]
        out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
                                                      torch.qint8, -1, False)
        return quantized_output, out1
    
    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                    offset: torch.Tensor):
        """
        Replacement for the AddRMSNormQuantSPPattern fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm_quant(
            rms_norm_input,
            residual,
            rms_norm_weight,
            # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
            1. / scale,
            offset,
            epsilon=epsilon)
        quantized_output = output[0]
        out1 = output[2]
        quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            quantized_output, True)
        return quantized_output, out1

    def get_inputs():
        """
        Generate example inputs for the AddRMSNormQuantSPPattern fusion pattern.
        """
        rms_norm_input = torch.randn(2, 4, device="npu")
        residual = torch.randn(2, 4, device="npu")
        rms_norm_weight = torch.randn(4, device="npu")
        scale = torch.ones(4, device="npu")
        offset = torch.zeros(4, device="npu")
        return [rms_norm_input, residual, rms_norm_weight, scale, offset]

    import torchair

    torchair.register_replacement(search_fn=pattern,
                                  replace_fn=replacement,
                                  example_inputs=get_inputs(),
                                  extra_check=_extra_stream_scope_check)

Comment on lines 301 to 282
def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
    if 'torch_npu' not in sys.modules:
        logger.info(
            'The AddRMSNormQuantSPPatternWithBias fusion will only be enabled in a torch npu env.'
            'When there is no torch_npu in the env, skip fusion.')
        return

    def _extra_stream_scope_check(match: Match) -> bool:
        """
        Checks if all nodes in the same stream.
        """
        non_default_streams = set()
        has_default = False

        for node in match.nodes:
            if node.op == "call_function":
                current_stream = node.meta.get("stream_label")
                if current_stream is None:
                    has_default = True
                else:
                    logger.debug(
                        f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPatternWithBias. "
                        f"Multiple streams found: {non_default_streams}. "
                        f"Fusion is not supported for cross-stream operations."
                    )
                    return False

        if has_default and len(non_default_streams) > 0:
            logger.debug(
                f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPatternWithBias. "
                f"Multiple streams found: {non_default_streams}. "
                f"Fusion is not supported for cross-stream operatiosn.")
            return False

        return True

    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                offset: torch.Tensor, bias: torch.Tensor):
        """
        Pattern for AddRMSNormQuantSPPatternWithBias fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
                                                rms_norm_weight, epsilon)
        out0 = output[0]
        out1 = output[2]
        out0 = out0 + bias
        out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
                                                      torch.qint8, -1, False)
        return quantized_output, out1

    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                    offset: torch.Tensor, bias: torch.Tensor):
        """
        Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm_quant(
            rms_norm_input,
            residual,
            rms_norm_weight,
            # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
            1. / scale,
            offset,
            epsilon=epsilon,
            beta=bias)
        quantized_output = output[0]
        out1 = output[2]
        quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            quantized_output, True)
        return quantized_output, out1

    def get_inputs():
        """
        Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern.
        """
        rms_norm_input = torch.randn(2, 4, device="npu")
        residual = torch.randn(2, 4, device="npu")
        rms_norm_weight = torch.randn(4, device="npu")
        rmsnorm_bias = torch.randn(4, device="npu")
        scale = torch.ones(4, device="npu")
        offset = torch.zeros(4, device="npu")
        return [
            rms_norm_input, residual, rms_norm_weight, scale, offset,
            rmsnorm_bias
        ]

critical

Similar to the other new pattern, the implementation of replacement_add_rms_norm_quant_sp_pattern_with_bias has several critical issues:

  1. Syntax Error: An indentation error exists in the if 'torch_npu' not in sys.modules: block.
  2. Incorrect Stream Check: The _extra_stream_scope_check logic is incorrect and will prevent valid fusions.
  3. Missing Registration: The fusion pattern is not registered with torchair.register_replacement, rendering it non-functional.

The suggested code below fixes these issues.

def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
    if 'torch_npu' not in sys.modules:
        logger.info(
            'The AddRMSNormQuantSPPatternWithBias fusion will only be enabled in a torch npu env.'
            'When there is no torch_npu in the env, skip fusion.')
        return

    def _extra_stream_scope_check(match: Match) -> bool:
        """
        Checks if all nodes in the same stream.
        """
        non_default_streams = set()
        has_default = False

        for node in match.nodes:
            if node.op == "call_function":
                current_stream = node.meta.get("stream_label")
                if current_stream is None:
                    has_default = True
                else:
                    non_default_streams.add(current_stream)
                    if len(non_default_streams) > 1:
                        logger.debug(
                            f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPatternWithBias. "
                            f"Multiple streams found: {non_default_streams}. "
                            f"Fusion is not supported for cross-stream operations."
                        )
                        return False

        if has_default and len(non_default_streams) > 0:
            logger.debug(
                f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPatternWithBias. "
                f"Multiple streams found: {non_default_streams}. "
                f"Fusion is not supported for cross-stream operations.")
            return False
        
        return True

    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                offset: torch.Tensor, bias: torch.Tensor):
        """
        Pattern for AddRMSNormQuantSPPatternWithBias fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
                                                rms_norm_weight, epsilon)
        out0 = output[0]
        out1 = output[2]
        out0 = out0 + bias
        out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
                                                      torch.qint8, -1, False)
        return quantized_output, out1
    
    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                    offset: torch.Tensor, bias: torch.Tensor):
        """
        Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm_quant(
            rms_norm_input,
            residual,
            rms_norm_weight,
            # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
            1. / scale,
            offset,
            epsilon=epsilon,
            beta=bias)
        quantized_output = output[0]
        out1 = output[2]
        quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            quantized_output, True)
        return quantized_output, out1

    def get_inputs():
        """
        Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern.
        """
        rms_norm_input = torch.randn(2, 4, device="npu")
        residual = torch.randn(2, 4, device="npu")
        rms_norm_weight = torch.randn(4, device="npu")
        rmsnorm_bias = torch.randn(4, device="npu")
        scale = torch.ones(4, device="npu")
        offset = torch.zeros(4, device="npu")
        return [
            rms_norm_input, residual, rms_norm_weight, scale, offset,
            rmsnorm_bias
        ]

    import torchair

    torchair.register_replacement(search_fn=pattern,
                                  replace_fn=replacement,
                                  example_inputs=get_inputs(),
                                  extra_check=_extra_stream_scope_check)

@ForBetterCodeNine ForBetterCodeNine force-pushed the sp_pattern branch 3 times, most recently from 370a817 to 3eb4065 Compare January 4, 2026 03:20

github-actions bot commented Jan 4, 2026

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

@ForBetterCodeNine ForBetterCodeNine force-pushed the sp_pattern branch 7 times, most recently from b14e8a8 to 254f02f Compare January 6, 2026 01:23
extra_check=_extra_stream_scope_check)


@functools.lru_cache(None)
Collaborator

@whx-sjtu whx-sjtu Jan 6, 2026


Why do we need lru_cache here? If it's not needed, please delete it.
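
(For context, a minimal sketch of what `functools.lru_cache(None)` buys on a registration helper: repeated calls with the same arguments become cache hits, so the patterns are registered only once. The helper name and the print are illustrative, not this PR's actual code.)

```python
import functools

@functools.lru_cache(None)
def register_sp_patterns(epsilon: float) -> None:
    # The body runs at most once per distinct epsilon; later calls with the
    # same value hit the cache and skip re-registering the patterns.
    print(f"registering SP fusion patterns, epsilon={epsilon}")

register_sp_patterns(1e-6)
register_sp_patterns(1e-6)  # cached: the body does not run again
```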

Contributor Author


ok

'When there is no torch_npu in the env, skip fusion.')
return

def _extra_stream_scope_check(match: Match) -> bool:
Collaborator


Please abstract this function as a public method.
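
(A rough sketch of what pulling the check into a shared module-level helper might look like; the helper name, the `Match` import path, and the logger setup are assumptions rather than code from this PR.)

```python
import functools
import logging

from torch._inductor.pattern_matcher import Match  # assumed import path

logger = logging.getLogger(__name__)


def stream_scope_check(match: Match, pattern_name: str) -> bool:
    """Return True only if every call_function node in the match runs on a single stream."""
    non_default_streams = set()
    has_default = False
    for node in match.nodes:
        if node.op == "call_function":
            stream = node.meta.get("stream_label")
            if stream is None:
                has_default = True
            else:
                non_default_streams.add(stream)
    if len(non_default_streams) > 1 or (has_default and non_default_streams):
        logger.debug("Cross-stream match for %s, skipping fusion: %s",
                     pattern_name, non_default_streams)
        return False
    return True


# Each pattern can then bind its own name when wiring the extra_check:
sp_check = functools.partial(stream_scope_check,
                             pattern_name="AddRMSNormQuantSPPattern")
```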

Contributor Author


ok

Signed-off-by: cjian <[email protected]>
@wangxiyuan wangxiyuan merged commit bdedf3c into vllm-project:main Jan 7, 2026
19 checks passed
@ForBetterCodeNine ForBetterCodeNine deleted the sp_pattern branch January 7, 2026 01:05
Rozwel-dx pushed a commit to Rozwel-dx/vllm-ascend that referenced this pull request Jan 8, 2026
…as (vllm-project#5569)

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@7157596

Signed-off-by: cjian <[email protected]>