[Graph][Fusion] Add AddRMSNormSPPattern and AddRMSNormSPPatternWithBias #5569
Conversation
Code Review
This pull request introduces two new fusion patterns, AddRMSNormSPPattern and AddRMSNormSPPatternWithBias. However, the implementation contains several critical bugs that prevent it from working correctly. Both new functions have Python syntax errors due to incorrect indentation, flawed logic for stream checking that would prevent fusions, and most importantly, they are missing the call to register the fusion patterns, making them completely ineffective. I have provided two comprehensive review comments with code suggestions that fix all these issues for both functions.
def replacement_add_rms_norm_quant_sp_pattern(epsilon):
    if 'torch_npu' not in sys.modules:
        logger.info(
            'The AddRMSNormQuantSPPattern fusion will only be enabled in a torch npu env.'
            'When there is no torch_npu in the env, skip fusion.')
        return

    def _extra_stream_scope_check(match: Match) -> bool:
        """
        Checks if all nodes in the same stream.
        """
        non_default_streams = set()
        has_default = False

        for node in match.nodes:
            if node.op == "call_function":
                current_stream = node.meta.get("stream_label")
                if current_stream is None:
                    has_default = True
                else:
                    logger.debug(
                        f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPattern. "
                        f"Multiple streams found: {non_default_streams}. "
                        f"Fusion is not supported for cross-stream operations."
                    )
                    return False

        if has_default and len(non_default_streams) > 0:
            logger.debug(
                f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPattern. "
                f"Multiple streams found: {non_default_streams}. "
                f"Fusion is not supported for cross-stream operatiosn.")
            return False

        return True

    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                offset: torch.Tensor):
        """
        Pattern for AddRMSNormQuantSPPattern fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
                                                rms_norm_weight, epsilon)
        out0 = output[0]
        out1 = output[2]
        out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
                                                      torch.qint8, -1, False)
        return quantized_output, out1

    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                    offset: torch.Tensor):
        """
        Replacement for the AddRMSNormQuantSPPattern fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm_quant(
            rms_norm_input,
            residual,
            rms_norm_weight,
            # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
            1. / scale,
            offset,
            epsilon=epsilon)
        quantized_output = output[0]
        out1 = output[2]
        quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            quantized_output, True)
        return quantized_output, out1

    def get_inputs():
        """
        Generate example inputs for the AddRMSNormQuantSPPattern fusion pattern.
        """
        rms_norm_input = torch.randn(2, 4, device="npu")
        residual = torch.randn(2, 4, device="npu")
        rms_norm_weight = torch.randn(4, device="npu")
        scale = torch.ones(4, device="npu")
        offset = torch.zeros(4, device="npu")
        return [rms_norm_input, residual, rms_norm_weight, scale, offset]
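As an aside on the stream check quoted above: the loop never adds anything to non_default_streams and rejects the match as soon as it meets any node carrying a stream label, so even a match whose nodes all sit on the same non-default stream is refused. A minimal reproduction of that behaviour, using SimpleNamespace stand-ins for the matcher's node objects (the real Match/node classes come from the pattern matcher and are not shown here):

from types import SimpleNamespace

def buggy_stream_check(nodes) -> bool:
    # Mirrors the loop above: the set is never populated and the first
    # labelled node triggers an immediate rejection.
    non_default_streams = set()
    for node in nodes:
        if node.op == "call_function":
            if node.meta.get("stream_label") is None:
                continue
            return False  # the original code logs and bails out here
    return True

# Three ops all on the same non-default stream "s1": a legitimate
# single-stream match, yet the check above refuses to fuse it.
nodes = [SimpleNamespace(op="call_function", meta={"stream_label": "s1"})
         for _ in range(3)]
assert buggy_stream_check(nodes) is False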
The implementation of `replacement_add_rms_norm_quant_sp_pattern` has several critical issues that need to be addressed:

- Syntax Error: There's an indentation error in the `if 'torch_npu' not in sys.modules:` block, which will cause a `SyntaxError`.
- Incorrect Stream Check: The `_extra_stream_scope_check` logic is flawed. It incorrectly flags valid single non-default-stream scenarios as cross-stream operations, preventing fusion.
- Missing Registration: The function defines a fusion pattern but never registers it with `torchair.register_replacement`. This makes the entire function ineffective.

The suggested code below fixes these issues.
def replacement_add_rms_norm_quant_sp_pattern(epsilon):
    if 'torch_npu' not in sys.modules:
        logger.info(
            'The AddRMSNormQuantSPPattern fusion will only be enabled in a torch npu env.'
            'When there is no torch_npu in the env, skip fusion.')
        return

    def _extra_stream_scope_check(match: Match) -> bool:
        """
        Checks if all nodes in the same stream.
        """
        non_default_streams = set()
        has_default = False
        for node in match.nodes:
            if node.op == "call_function":
                current_stream = node.meta.get("stream_label")
                if current_stream is None:
                    has_default = True
                else:
                    non_default_streams.add(current_stream)
                    if len(non_default_streams) > 1:
                        logger.debug(
                            f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPattern. "
                            f"Multiple streams found: {non_default_streams}. "
                            f"Fusion is not supported for cross-stream operations."
                        )
                        return False
        if has_default and len(non_default_streams) > 0:
            logger.debug(
                f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPattern. "
                f"Multiple streams found: {non_default_streams}. "
                f"Fusion is not supported for cross-stream operations.")
            return False
        return True

    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                offset: torch.Tensor):
        """
        Pattern for AddRMSNormQuantSPPattern fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
                                                rms_norm_weight, epsilon)
        out0 = output[0]
        out1 = output[2]
        out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
                                                      torch.qint8, -1, False)
        return quantized_output, out1

    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                    offset: torch.Tensor):
        """
        Replacement for the AddRMSNormQuantSPPattern fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm_quant(
            rms_norm_input,
            residual,
            rms_norm_weight,
            # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
            1. / scale,
            offset,
            epsilon=epsilon)
        quantized_output = output[0]
        out1 = output[2]
        quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            quantized_output, True)
        return quantized_output, out1

    def get_inputs():
        """
        Generate example inputs for the AddRMSNormQuantSPPattern fusion pattern.
        """
        rms_norm_input = torch.randn(2, 4, device="npu")
        residual = torch.randn(2, 4, device="npu")
        rms_norm_weight = torch.randn(4, device="npu")
        scale = torch.ones(4, device="npu")
        offset = torch.zeros(4, device="npu")
        return [rms_norm_input, residual, rms_norm_weight, scale, offset]

    import torchair
    torchair.register_replacement(search_fn=pattern,
                                  replace_fn=replacement,
                                  example_inputs=get_inputs(),
                                  extra_check=_extra_stream_scope_check)
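A note on the 1. / scale argument in the replacement above: the inline comment says the fused kernel expects the reciprocal of the scale that npu_quantize takes. Assuming, as that comment implies, that one kernel divides by its scale argument while the other multiplies (the exact kernel semantics are not spelled out in this diff), feeding the reciprocal keeps the two graphs numerically equivalent:

# Illustration only; assumes npu_quantize divides by `scale` while
# npu_add_rms_norm_quant multiplies by its scale argument, as the inline
# comment in the replacement suggests.
x, scale, offset = 3.2, 0.25, 0.0

unfused = round(x / scale + offset)         # standalone npu_quantize path
fused = round(x * (1.0 / scale) + offset)   # fused kernel fed 1/scale
assert unfused == fused == 13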
def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
    if 'torch_npu' not in sys.modules:
        logger.info(
            'The AddRMSNormQuantSPPatternWithBias fusion will only be enabled in a torch npu env.'
            'When there is no torch_npu in the env, skip fusion.')
        return

    def _extra_stream_scope_check(match: Match) -> bool:
        """
        Checks if all nodes in the same stream.
        """
        non_default_streams = set()
        has_default = False

        for node in match.nodes:
            if node.op == "call_function":
                current_stream = node.meta.get("stream_label")
                if current_stream is None:
                    has_default = True
                else:
                    logger.debug(
                        f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPatternWithBias. "
                        f"Multiple streams found: {non_default_streams}. "
                        f"Fusion is not supported for cross-stream operations."
                    )
                    return False

        if has_default and len(non_default_streams) > 0:
            logger.debug(
                f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPatternWithBias. "
                f"Multiple streams found: {non_default_streams}. "
                f"Fusion is not supported for cross-stream operatiosn.")
            return False

        return True

    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                offset: torch.Tensor, bias: torch.Tensor):
        """
        Pattern for AddRMSNormQuantSPPatternWithBias fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
                                                rms_norm_weight, epsilon)
        out0 = output[0]
        out1 = output[2]
        out0 = out0 + bias
        out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
                                                      torch.qint8, -1, False)
        return quantized_output, out1

    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                    offset: torch.Tensor, bias: torch.Tensor):
        """
        Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm_quant(
            rms_norm_input,
            residual,
            rms_norm_weight,
            # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
            1. / scale,
            offset,
            epsilon=epsilon,
            beta=bias)
        quantized_output = output[0]
        out1 = output[2]
        quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            quantized_output, True)
        return quantized_output, out1

    def get_inputs():
        """
        Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern.
        """
        rms_norm_input = torch.randn(2, 4, device="npu")
        residual = torch.randn(2, 4, device="npu")
        rms_norm_weight = torch.randn(4, device="npu")
        rmsnorm_bias = torch.randn(4, device="npu")
        scale = torch.ones(4, device="npu")
        offset = torch.zeros(4, device="npu")
        return [
            rms_norm_input, residual, rms_norm_weight, scale, offset,
            rmsnorm_bias
        ]
Similar to the other new pattern, the implementation of `replacement_add_rms_norm_quant_sp_pattern_with_bias` has several critical issues:

- Syntax Error: An indentation error exists in the `if 'torch_npu' not in sys.modules:` block.
- Incorrect Stream Check: The `_extra_stream_scope_check` logic is incorrect and will prevent valid fusions.
- Missing Registration: The fusion pattern is not registered with `torchair.register_replacement`, rendering it non-functional.

The suggested code below fixes these issues.
def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
    if 'torch_npu' not in sys.modules:
        logger.info(
            'The AddRMSNormQuantSPPatternWithBias fusion will only be enabled in a torch npu env.'
            'When there is no torch_npu in the env, skip fusion.')
        return

    def _extra_stream_scope_check(match: Match) -> bool:
        """
        Checks if all nodes in the same stream.
        """
        non_default_streams = set()
        has_default = False
        for node in match.nodes:
            if node.op == "call_function":
                current_stream = node.meta.get("stream_label")
                if current_stream is None:
                    has_default = True
                else:
                    non_default_streams.add(current_stream)
                    if len(non_default_streams) > 1:
                        logger.debug(
                            f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPatternWithBias. "
                            f"Multiple streams found: {non_default_streams}. "
                            f"Fusion is not supported for cross-stream operations."
                        )
                        return False
        if has_default and len(non_default_streams) > 0:
            logger.debug(
                f"Cross-stream operation detected in pattern match for AddRMSNormQuantSPPatternWithBias. "
                f"Multiple streams found: {non_default_streams}. "
                f"Fusion is not supported for cross-stream operations.")
            return False
        return True

    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                offset: torch.Tensor, bias: torch.Tensor):
        """
        Pattern for AddRMSNormQuantSPPatternWithBias fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
                                                rms_norm_weight, epsilon)
        out0 = output[0]
        out1 = output[2]
        out0 = out0 + bias
        out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
                                                      torch.qint8, -1, False)
        return quantized_output, out1

    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                    offset: torch.Tensor, bias: torch.Tensor):
        """
        Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
        """
        output = torch.ops.npu.npu_add_rms_norm_quant(
            rms_norm_input,
            residual,
            rms_norm_weight,
            # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
            1. / scale,
            offset,
            epsilon=epsilon,
            beta=bias)
        quantized_output = output[0]
        out1 = output[2]
        quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            quantized_output, True)
        return quantized_output, out1

    def get_inputs():
        """
        Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern.
        """
        rms_norm_input = torch.randn(2, 4, device="npu")
        residual = torch.randn(2, 4, device="npu")
        rms_norm_weight = torch.randn(4, device="npu")
        rmsnorm_bias = torch.randn(4, device="npu")
        scale = torch.ones(4, device="npu")
        offset = torch.zeros(4, device="npu")
        return [
            rms_norm_input, residual, rms_norm_weight, scale, offset,
            rmsnorm_bias
        ]

    import torchair
    torchair.register_replacement(search_fn=pattern,
                                  replace_fn=replacement,
                                  example_inputs=get_inputs(),
                                  extra_check=_extra_stream_scope_check)
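Both registration functions above take epsilon as an argument, so some pass-setup code still has to call them once with the model's RMSNorm epsilon for the patterns to take effect. A hypothetical wiring sketch; the entry-point name and the epsilon value are illustrative assumptions, not code from this PR:

# Hypothetical pass-setup hook; the function name and epsilon value are
# assumptions for illustration only.
RMS_NORM_EPSILON = 1e-6

def register_add_rms_norm_quant_sp_patterns() -> None:
    # Each call builds the pattern/replacement/example inputs and hands
    # them to torchair.register_replacement (see the functions above).
    replacement_add_rms_norm_quant_sp_pattern(RMS_NORM_EPSILON)
    replacement_add_rms_norm_quant_sp_pattern_with_bias(RMS_NORM_EPSILON)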
                                  extra_check=_extra_stream_scope_check)


@functools.lru_cache(None)
Why do we need lru_cache here? If it's not needed, please delete it.
ok
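For context on the question above: functools.lru_cache(None) on a zero-argument function is sometimes used purely to make it idempotent, i.e. the body runs once and later calls return the cached result, which guards against registering the same fusion patterns twice. A small standalone illustration of that idiom, not code from this PR:

import functools

registered_patterns = []

@functools.lru_cache(None)
def register_patterns_once() -> None:
    # Runs only on the first call; subsequent calls hit the cache and are
    # effectively no-ops, so nothing gets registered twice.
    registered_patterns.append("AddRMSNormQuantSPPattern")

register_patterns_once()
register_patterns_once()
assert registered_patterns == ["AddRMSNormQuantSPPattern"]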
            'When there is no torch_npu in the env, skip fusion.')
        return

    def _extra_stream_scope_check(match: Match) -> bool:
Please abstract this function out into a shared public helper method instead of duplicating it in each pattern.
ok
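A rough sketch of the refactor being asked for: one module-level helper shared by both patterns instead of two nested copies of _extra_stream_scope_check. The factory name and the way the pattern name is threaded in are assumptions, not the final implementation; Match and logger come from the surrounding module, as in the patterns above.

# Sketch only; the helper name and signature are assumptions.
def make_stream_scope_check(pattern_name: str):
    """Build an extra_check callable that rejects cross-stream matches."""

    def _extra_stream_scope_check(match: Match) -> bool:
        non_default_streams = set()
        has_default = False
        for node in match.nodes:
            if node.op == "call_function":
                current_stream = node.meta.get("stream_label")
                if current_stream is None:
                    has_default = True
                else:
                    non_default_streams.add(current_stream)
        # Reject if nodes span multiple streams, or mix default and
        # non-default streams.
        if len(non_default_streams) > 1 or (has_default and non_default_streams):
            logger.debug(
                f"Cross-stream match for {pattern_name}; "
                f"streams found: {non_default_streams}. Skipping fusion.")
            return False
        return True

    return _extra_stream_scope_check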
What this PR does / why we need it?
This PR builds upon PR #5011 and further enhances the npu_graph_ex_passes module. Building on that prior work, it adds graph optimization support for the add_rms_quant fused operator in scenarios where a bias term is present, ensuring the fusion pattern is correctly registered and matched into the computation graph.
For validation, we switched to the Qwen3-235B-A22B-W8A8 model for SPPatternWithBias and the Qwen3-32B model for SPPattern. Benchmark results show that, compared to the unfused baseline, enabling this fusion pass significantly improves inference throughput for W8A8 quantized models.
For more details, refer to the RFC: #4715

Does this PR introduce any user-facing change?
No.

How was this patch tested?
```
llm = LLM(
    model=model,
    tensor_parallel_size=GPUs_per_dp_rank,
    enforce_eager=False,
    enable_expert_parallel=enable_expert_parallel,
    trust_remote_code=trust_remote_code,
    gpu_memory_utilization=0.98,
    max_num_batched_tokens=512,
    # load_format="dummy",
    max_model_len=2048,
    max_num_seqs=16,
    quantization="ascend",
    additional_config={
        "refresh": True,
        "enable_npugraph_ex": True
    },
    compilation_config={
        "cudagraph_capture_sizes": [8, 16],
        "cudagraph_mode": "FULL_DECODE_ONLY",
    },
)

if profile_dir:
    llm.start_profile()
outputs = llm.generate(prompts, sampling_params)
if profile_dir:
    llm.stop_profile()

for i, output in enumerate(outputs):
    if i >= 5:
        break
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(
        f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
        f"Generated text: {generated_text!r}"
    )
```

- vLLM version: v0.13.0
- vLLM main: vllm-project/vllm@7157596

Signed-off-by: cjian <[email protected]>