Skip to content

Commit 6e91864

Browse files
add rms norm quant
Signed-off-by: cjian <2318164299@qq.com>
1 parent 799b41a commit 6e91864

File tree

2 files changed

+206
-73
lines changed

2 files changed

+206
-73
lines changed

tests/ut/compilation/test_add_rms_norm_quant.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,23 @@
1616
import sys
1717
from unittest import mock
1818

19+
import torch
20+
21+
22+
def get_inputs():
    """
    Build example tensors for the AddRMSNormQuantSPPatternWithBias fusion pattern.

    No device argument is passed, so tensors are created on torch's default
    device. The returned list is ordered as
    [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias].
    """
    rows, hidden = 2, 4
    # Random draws happen in the order input -> residual -> weight -> bias,
    # which keeps results reproducible under a fixed seed.
    rms_norm_input = torch.randn(rows, hidden)
    residual = torch.randn(rows, hidden)
    rms_norm_weight = torch.randn(hidden)
    rmsnorm_bias = torch.randn(hidden)
    return [
        rms_norm_input,
        residual,
        rms_norm_weight,
        torch.ones(hidden),  # scale
        torch.zeros(hidden),  # offset
        rmsnorm_bias,
    ]
35+
1936

2037
def _extra_stream_scope_check_for_test(match) -> bool:
2138
"""
@@ -93,3 +110,38 @@ def test_replacement_function_without_torch_npu(caplog):
93110
assert result is None
94111
except (ImportError, AttributeError):
95112
pass
113+
114+
def test_get_inputs_sp_pattern_with_bias():
    """
    Test that get_inputs generates tensors with the expected count, shapes
    and constant values.

    The helper under test is the module-level get_inputs of this test file.
    torch is imported at module level, so no import guard is needed here: a
    missing torch would already fail at module import, and a silent early
    return would only make the test pass without checking anything.
    """
    inputs = get_inputs()

    # Verify number of inputs first so a wrong count fails with a clear
    # assertion instead of an unpacking error.
    assert len(inputs) == 6

    (
        rms_norm_input,
        residual,
        rms_norm_weight,
        scale,
        offset,
        rmsnorm_bias,
    ) = inputs

    # Verify shapes
    assert rms_norm_input.shape == (2, 4)
    assert residual.shape == (2, 4)
    assert rms_norm_weight.shape == (4, )
    assert rmsnorm_bias.shape == (4, )
    assert scale.shape == (4, )
    assert offset.shape == (4, )

    # Verify specific values
    assert torch.all(scale == 1.0)
    assert torch.all(offset == 0.0)

vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py

Lines changed: 154 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -23,45 +23,41 @@
2323
from vllm.logger import logger
2424

2525

26+
def _extra_stream_scope_check(match: Match) -> bool:
    """
    Check that all call_function nodes of a match are in the same stream.

    A node whose meta has no "stream_label" entry is treated as running on
    the default stream. The match is rejected (False) when two different
    labelled streams are seen, or when default-stream nodes are mixed with
    labelled ones; otherwise True is returned.
    """
    labelled_streams = set()
    saw_default = False
    cross_stream = False

    for fx_node in match.nodes:
        if fx_node.op != "call_function":
            continue
        label = fx_node.meta.get("stream_label")
        if label is None:
            saw_default = True
            continue
        labelled_streams.add(label)
        if len(labelled_streams) > 1:
            # Stop at the first second stream; nothing later can change the verdict.
            cross_stream = True
            break

    if not cross_stream:
        # Default-stream nodes mixed with any labelled stream also count as cross-stream.
        cross_stream = saw_default and len(labelled_streams) > 0

    if cross_stream:
        logger.debug(
            f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
            f"Multiple streams found: {labelled_streams}. "
            f"Fusion is not supported for cross-stream operations.")
        return False

    return True
56+
57+
2658
@functools.lru_cache(None)
2759
# The replacement registered here will actually be executed after AOT.
2860
def replacement_add_rms_norm_quant(epsilon):
29-
if 'torch_npu' not in sys.modules:
30-
logger.info(
31-
'The AddRMSNormQuant fusion will only be enabled in a torch npu env.'
32-
'When there is no torch_npu in the env, skip fusion.')
33-
return
34-
35-
def _extra_stream_scope_check(match: Match) -> bool:
36-
"""
37-
Checks if all nodes in the same stream.
38-
"""
39-
non_default_streams = set()
40-
has_default = False
41-
42-
for node in match.nodes:
43-
if node.op == "call_function":
44-
current_stream = node.meta.get("stream_label")
45-
if current_stream is None:
46-
has_default = True
47-
else:
48-
non_default_streams.add(current_stream)
49-
if len(non_default_streams) > 1:
50-
logger.debug(
51-
f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
52-
f"Multiple streams found: {non_default_streams}. "
53-
f"Fusion is not supported for cross-stream operations."
54-
)
55-
return False
56-
57-
if has_default and len(non_default_streams) > 0:
58-
logger.debug(
59-
f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
60-
f"Multiple streams found: {non_default_streams}. "
61-
f"Fusion is not supported for cross-stream operations.")
62-
return False
63-
64-
return True
6561

6662
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
6763
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
@@ -114,45 +110,8 @@ def get_inputs():
114110
extra_check=_extra_stream_scope_check)
115111

116112

117-
@functools.lru_cache(None)
118113
# The replacement registered here will actually be executed after AOT.
119114
def replacement_add_rms_norm_quant_with_bias(epsilon):
120-
if 'torch_npu' not in sys.modules:
121-
logger.info(
122-
'The AddRMSNormQuantWithBias fusion will only be enabled in a torch npu env.'
123-
'When there is no torch_npu in the env, skip fusion.')
124-
return
125-
126-
def _extra_stream_scope_check(match: Match) -> bool:
127-
"""
128-
Checks if all nodes in the same stream.
129-
"""
130-
non_default_streams = set()
131-
has_default = False
132-
133-
for node in match.nodes:
134-
if node.op == "call_function":
135-
current_stream = node.meta.get("stream_label")
136-
if current_stream is None:
137-
has_default = True
138-
else:
139-
non_default_streams.add(current_stream)
140-
if len(non_default_streams) > 1:
141-
logger.debug(
142-
f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. "
143-
f"Multiple streams found: {non_default_streams}. "
144-
f"Fusion is not supported for cross-stream operations."
145-
)
146-
return False
147-
148-
if has_default and len(non_default_streams) > 0:
149-
logger.debug(
150-
f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. "
151-
f"Multiple streams found: {non_default_streams}. "
152-
f"Fusion is not supported for cross-stream operations.")
153-
return False
154-
155-
return True
156115

157116
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
158117
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
@@ -211,6 +170,126 @@ def get_inputs():
211170
extra_check=_extra_stream_scope_check)
212171

213172

173+
# The replacement registered here will actually be executed after AOT.
174+
def replacement_add_rms_norm_quant_sp_pattern(epsilon):
    """
    Register the AddRMSNormQuantSPPattern fusion with torchair for one epsilon.

    The search pattern is npu_add_rms_norm -> maybe_all_gather_and_maybe_unpad
    -> npu_quantize; the replacement is npu_add_rms_norm_quant ->
    maybe_all_gather_and_maybe_unpad. `epsilon` is captured by both closures.
    """

    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                offset: torch.Tensor):
        """
        Pattern for AddRMSNormQuantSPPattern fusion.
        """
        norm_outputs = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
                                                      rms_norm_weight, epsilon)
        normed = norm_outputs[0]
        residual_out = norm_outputs[2]
        gathered = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(normed, True)
        quantized = torch.ops.npu.npu_quantize(gathered, scale, offset,
                                               torch.qint8, -1, False)
        return quantized, residual_out

    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                    offset: torch.Tensor):
        """
        Replacement for the AddRMSNormQuantSPPattern fusion.
        """
        fused_outputs = torch.ops.npu.npu_add_rms_norm_quant(
            rms_norm_input,
            residual,
            rms_norm_weight,
            # npu_add_rms_norm_quant expects the inverse of the scale that
            # npu_quantize takes, hence the reciprocal.
            1. / scale,
            offset,
            epsilon=epsilon)
        quantized = fused_outputs[0]
        residual_out = fused_outputs[2]
        gathered = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            quantized, True)
        return gathered, residual_out

    def get_inputs():
        """
        Generate example inputs for the AddRMSNormQuantSPPattern fusion pattern.
        """
        npu = "npu"
        example_input = torch.randn(2, 4, device=npu)
        example_residual = torch.randn(2, 4, device=npu)
        example_weight = torch.randn(4, device=npu)
        example_scale = torch.ones(4, device=npu)
        example_offset = torch.zeros(4, device=npu)
        return [
            example_input, example_residual, example_weight, example_scale,
            example_offset
        ]

    import torchair

    torchair.register_replacement(search_fn=pattern,
                                  replace_fn=replacement,
                                  example_inputs=get_inputs(),
                                  extra_check=_extra_stream_scope_check)
228+
229+
230+
# The replacement registered here will actually be executed after AOT.
231+
def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
    """
    Register the AddRMSNormQuantSPPatternWithBias fusion with torchair for one epsilon.

    The search pattern is npu_add_rms_norm -> bias add ->
    maybe_all_gather_and_maybe_unpad -> npu_quantize; the replacement is
    npu_add_rms_norm_quant (bias passed as `beta`) ->
    maybe_all_gather_and_maybe_unpad. `epsilon` is captured by both closures.
    """

    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                offset: torch.Tensor, bias: torch.Tensor):
        """
        Pattern for AddRMSNormQuantSPPatternWithBias fusion.
        """
        norm_outputs = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
                                                      rms_norm_weight, epsilon)
        normed = norm_outputs[0]
        residual_out = norm_outputs[2]
        biased = normed + bias
        gathered = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(biased, True)
        quantized = torch.ops.npu.npu_quantize(gathered, scale, offset,
                                               torch.qint8, -1, False)
        return quantized, residual_out

    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                    offset: torch.Tensor, bias: torch.Tensor):
        """
        Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
        """
        fused_outputs = torch.ops.npu.npu_add_rms_norm_quant(
            rms_norm_input,
            residual,
            rms_norm_weight,
            # npu_add_rms_norm_quant expects the inverse of the scale that
            # npu_quantize takes, hence the reciprocal.
            1. / scale,
            offset,
            epsilon=epsilon,
            beta=bias)
        quantized = fused_outputs[0]
        residual_out = fused_outputs[2]
        gathered = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            quantized, True)
        return gathered, residual_out

    def get_inputs():
        """
        Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern.
        """
        npu = "npu"
        example_input = torch.randn(2, 4, device=npu)
        example_residual = torch.randn(2, 4, device=npu)
        example_weight = torch.randn(4, device=npu)
        example_bias = torch.randn(4, device=npu)
        example_scale = torch.ones(4, device=npu)
        example_offset = torch.zeros(4, device=npu)
        # Order matches the positional parameters of pattern/replacement:
        # bias comes last even though it is created before scale and offset.
        return [
            example_input, example_residual, example_weight, example_scale,
            example_offset, example_bias
        ]

    import torchair

    torchair.register_replacement(search_fn=pattern,
                                  replace_fn=replacement,
                                  example_inputs=get_inputs(),
                                  extra_check=_extra_stream_scope_check)
291+
292+
214293
# register converter for pass
215294
common_epsilons = [1e-5, 1e-6]
216295
for eps in common_epsilons:
@@ -219,3 +298,5 @@ def get_inputs():
219298
)
220299
replacement_add_rms_norm_quant(eps)
221300
replacement_add_rms_norm_quant_with_bias(eps)
301+
replacement_add_rms_norm_quant_sp_pattern(eps)
302+
replacement_add_rms_norm_quant_sp_pattern_with_bias(eps)

0 commit comments

Comments
 (0)