Commit 254f02f

add rms norm quant
Signed-off-by: cjian <2318164299@qq.com>
1 parent 799b41a commit 254f02f

File tree

  tests/ut/compilation/test_add_rms_norm_quant.py
  vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py

2 files changed: +273 −0 lines changed

tests/ut/compilation/test_add_rms_norm_quant.py

Lines changed: 77 additions & 0 deletions
@@ -16,6 +16,23 @@
 import sys
 from unittest import mock
 
+import torch
+
+
+def get_inputs():
+    """
+    Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern.
+    """
+    rms_norm_input = torch.randn(2, 4)
+    residual = torch.randn(2, 4)
+    rms_norm_weight = torch.randn(4)
+    rmsnorm_bias = torch.randn(4)
+    scale = torch.ones(4)
+    offset = torch.zeros(4)
+    return [
+        rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias
+    ]
+
 
 def _extra_stream_scope_check_for_test(match) -> bool:
     """
@@ -93,3 +110,63 @@ def test_replacement_function_without_torch_npu(caplog):
         assert result is None
     except (ImportError, AttributeError):
         pass
+
+
+def test_replacement_add_rms_norm_quant_sp_pattern_with_bias_without_torch_npu(
+        caplog):
+    """
+    Test that replacement_add_rms_norm_quant_sp_pattern_with_bias returns None
+    when torch_npu is not available.
+    """
+    with mock.patch.dict(sys.modules, {
+            'torch_npu': None,
+            'torchair': None,
+    }):
+        if 'vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant' in sys.modules:
+            del sys.modules[
+                'vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant']
+
+        try:
+            from vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant import \
+                replacement_add_rms_norm_quant_sp_pattern_with_bias
+            result = replacement_add_rms_norm_quant_sp_pattern_with_bias(
+                epsilon=1e-5)
+            assert result is None
+        except (ImportError, AttributeError):
+            pass
+
+
+def test_get_inputs_sp_pattern_with_bias():
+    """
+    Test that get_inputs produces tensors with the expected shapes and values.
+    This test exercises the module-level get_inputs helper used by the pattern.
+    """
+    try:
+        import torch
+    except ImportError:
+        return  # Skip if torch is not available
+
+    inputs = get_inputs()
+    (
+        rms_norm_input,
+        residual,
+        rms_norm_weight,
+        scale,
+        offset,
+        rmsnorm_bias,
+    ) = inputs
+
+    # Verify shapes
+    assert rms_norm_input.shape == (2, 4)
+    assert residual.shape == (2, 4)
+    assert rms_norm_weight.shape == (4, )
+    assert rmsnorm_bias.shape == (4, )
+    assert scale.shape == (4, )
+    assert offset.shape == (4, )
+
+    # Verify number of inputs
+    assert len(inputs) == 6
+
+    # Verify specific values
+    assert torch.all(scale == 1.0)
+    assert torch.all(offset == 0.0)
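
As an aside, the sys.modules trick these tests rely on reduces to a minimal, self-contained sketch. Patching an entry to None makes a later import of that name raise ImportError, which is how the tests simulate an environment without the NPU stack; everything below is illustrative and not part of the commit:

import sys
from unittest import mock

# Patching a sys.modules entry to None makes "import torch_npu" raise
# ImportError inside the with-block; mock.patch.dict restores the
# original sys.modules contents on exit.
with mock.patch.dict(sys.modules, {'torch_npu': None}):
    try:
        import torch_npu  # noqa: F401
        print("torch_npu importable")
    except ImportError:
        print("torch_npu unavailable inside the patched block")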

vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py

Lines changed: 196 additions & 0 deletions
@@ -211,6 +211,200 @@ def get_inputs():
                                   extra_check=_extra_stream_scope_check)
 
 
+@functools.lru_cache(None)
+# The replacement registered here will actually be executed after AOT.
+def replacement_add_rms_norm_quant_sp_pattern(epsilon):
+    if 'torch_npu' not in sys.modules:
+        logger.info(
+            'The AddRMSNormQuantSPPattern fusion will only be enabled in a torch_npu env. '
+            'When there is no torch_npu in the env, skip fusion.')
+        return
+
+    def _extra_stream_scope_check(match: Match) -> bool:
+        """
+        Check that all nodes in the match run on the same stream.
+        """
+        non_default_streams = set()
+        has_default = False
+
+        for node in match.nodes:
+            if node.op == "call_function":
+                current_stream = node.meta.get("stream_label")
+                if current_stream is None:
+                    has_default = True
+                else:
+                    non_default_streams.add(current_stream)
+                if len(non_default_streams) > 1:
+                    logger.debug(
+                        "Cross-stream operation detected in pattern match for AddRMSNormQuantSPPattern. "
+                        f"Multiple streams found: {non_default_streams}. "
+                        "Fusion is not supported for cross-stream operations.")
+                    return False
+
+        if has_default and len(non_default_streams) > 0:
+            logger.debug(
+                "Cross-stream operation detected in pattern match for AddRMSNormQuantSPPattern. "
+                f"Multiple streams found: {non_default_streams}. "
+                "Fusion is not supported for cross-stream operations.")
+            return False
+
+        return True
+
+    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
+                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
+                offset: torch.Tensor):
+        """
+        Pattern for the AddRMSNormQuantSPPattern fusion.
+        """
+        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
+                                                rms_norm_weight, epsilon)
+        out0 = output[0]
+        out1 = output[2]
+        out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
+        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
+                                                      torch.qint8, -1, False)
+        return quantized_output, out1
+
+    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
+                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
+                    offset: torch.Tensor):
+        """
+        Replacement for the AddRMSNormQuantSPPattern fusion.
+        """
+        output = torch.ops.npu.npu_add_rms_norm_quant(
+            rms_norm_input,
+            residual,
+            rms_norm_weight,
+            # The npu_add_rms_norm_quant kernel expects the inverse of the
+            # scale, which is the opposite of the npu_quantize kernel.
+            1. / scale,
+            offset,
+            epsilon=epsilon)
+        quantized_output = output[0]
+        out1 = output[2]
+        quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            quantized_output, True)
+        return quantized_output, out1
+
+    def get_inputs():
+        """
+        Generate example inputs for the AddRMSNormQuantSPPattern fusion pattern.
+        """
+        rms_norm_input = torch.randn(2, 4, device="npu")
+        residual = torch.randn(2, 4, device="npu")
+        rms_norm_weight = torch.randn(4, device="npu")
+        scale = torch.ones(4, device="npu")
+        offset = torch.zeros(4, device="npu")
+        return [rms_norm_input, residual, rms_norm_weight, scale, offset]
+
+    import torchair
+
+    torchair.register_replacement(search_fn=pattern,
+                                  replace_fn=replacement,
+                                  example_inputs=get_inputs(),
+                                  extra_check=_extra_stream_scope_check)
+
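The _extra_stream_scope_check above only inspects each matched call_function node's stream_label metadata, so its logic can be exercised without building an FX graph. A minimal sketch with stand-in node objects (SimpleNamespace and same_stream are illustrative stand-ins; real matches carry torch.fx.Node instances):

from types import SimpleNamespace

def same_stream(nodes) -> bool:
    # Mirrors the check: reject as soon as two distinct non-default
    # streams appear, or when default-stream and non-default-stream
    # nodes are mixed within one match.
    non_default, has_default = set(), False
    for node in nodes:
        if node.op != "call_function":
            continue
        label = node.meta.get("stream_label")
        if label is None:
            has_default = True
        else:
            non_default.add(label)
        if len(non_default) > 1:
            return False
    return not (has_default and non_default)

make = lambda label: SimpleNamespace(op="call_function",
                                     meta={"stream_label": label})
assert same_stream([make(None), make(None)])      # all default: fusible
assert not same_stream([make("s1"), make("s2")])  # two streams: skip
assert not same_stream([make(None), make("s1")])  # mixed: skip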
+
+@functools.lru_cache(None)
+# The replacement registered here will actually be executed after AOT.
+def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
+    if 'torch_npu' not in sys.modules:
+        logger.info(
+            'The AddRMSNormQuantSPPatternWithBias fusion will only be enabled in a torch_npu env. '
+            'When there is no torch_npu in the env, skip fusion.')
+        return
+
+    def _extra_stream_scope_check(match: Match) -> bool:
+        """
+        Check that all nodes in the match run on the same stream.
+        """
+        non_default_streams = set()
+        has_default = False
+
+        for node in match.nodes:
+            if node.op == "call_function":
+                current_stream = node.meta.get("stream_label")
+                if current_stream is None:
+                    has_default = True
+                else:
+                    non_default_streams.add(current_stream)
+                if len(non_default_streams) > 1:
+                    logger.debug(
+                        "Cross-stream operation detected in pattern match for AddRMSNormQuantSPPatternWithBias. "
+                        f"Multiple streams found: {non_default_streams}. "
+                        "Fusion is not supported for cross-stream operations.")
+                    return False
+
+        if has_default and len(non_default_streams) > 0:
+            logger.debug(
+                "Cross-stream operation detected in pattern match for AddRMSNormQuantSPPatternWithBias. "
+                f"Multiple streams found: {non_default_streams}. "
+                "Fusion is not supported for cross-stream operations.")
+            return False
+
+        return True
+
+    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
+                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
+                offset: torch.Tensor, bias: torch.Tensor):
+        """
+        Pattern for the AddRMSNormQuantSPPatternWithBias fusion.
+        """
+        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
+                                                rms_norm_weight, epsilon)
+        out0 = output[0]
+        out1 = output[2]
+        out0 = out0 + bias
+        out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
+        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
+                                                      torch.qint8, -1, False)
+        return quantized_output, out1
+
+    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
+                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
+                    offset: torch.Tensor, bias: torch.Tensor):
+        """
+        Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
+        """
+        output = torch.ops.npu.npu_add_rms_norm_quant(
+            rms_norm_input,
+            residual,
+            rms_norm_weight,
+            # The npu_add_rms_norm_quant kernel expects the inverse of the
+            # scale, which is the opposite of the npu_quantize kernel.
+            1. / scale,
+            offset,
+            epsilon=epsilon,
+            beta=bias)
+        quantized_output = output[0]
+        out1 = output[2]
+        quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            quantized_output, True)
+        return quantized_output, out1
+
+    def get_inputs():
+        """
+        Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern.
+        """
+        rms_norm_input = torch.randn(2, 4, device="npu")
+        residual = torch.randn(2, 4, device="npu")
+        rms_norm_weight = torch.randn(4, device="npu")
+        rmsnorm_bias = torch.randn(4, device="npu")
+        scale = torch.ones(4, device="npu")
+        offset = torch.zeros(4, device="npu")
+        return [
+            rms_norm_input, residual, rms_norm_weight, scale, offset,
+            rmsnorm_bias
+        ]
+
+    import torchair
+
+    torchair.register_replacement(search_fn=pattern,
+                                  replace_fn=replacement,
+                                  example_inputs=get_inputs(),
+                                  extra_check=_extra_stream_scope_check)
+
+
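The 1. / scale passed in both replacements is the subtle point of this change: npu_quantize divides its input by the scale, while npu_add_rms_norm_quant multiplies by the scale argument it receives, so the fused kernel must be handed the reciprocal. A CPU-only sketch of that equivalence (plain tensor arithmetic stands in for the NPU kernels; int8 rounding and saturation are omitted):

import torch

x = torch.randn(2, 4)
scale = torch.full((4,), 0.05)
offset = torch.zeros(4)

# npu_quantize-style: divide by the scale.
q_divide = x / scale + offset
# npu_add_rms_norm_quant-style: multiply by the reciprocal scale.
q_multiply = x * (1. / scale) + offset

# Up to floating-point rounding, the two formulations agree.
assert torch.allclose(q_divide, q_multiply)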
 # register converter for pass
 common_epsilons = [1e-5, 1e-6]
 for eps in common_epsilons:
@@ -219,3 +413,5 @@ def get_inputs():
     )
     replacement_add_rms_norm_quant(eps)
     replacement_add_rms_norm_quant_with_bias(eps)
+    replacement_add_rms_norm_quant_sp_pattern(eps)
+    replacement_add_rms_norm_quant_sp_pattern_with_bias(eps)
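
One more detail worth noting: because each registration function is decorated with @functools.lru_cache(None), re-running the loop above is harmless; a given epsilon registers its patterns exactly once. A minimal sketch of that idiom (register_once and the counter are illustrative, not from the commit):

import functools

calls = 0

@functools.lru_cache(None)
def register_once(epsilon):
    global calls
    calls += 1  # stands in for torchair.register_replacement(...)

for _ in range(3):   # re-running the registration loop...
    register_once(1e-5)
    register_once(1e-6)

assert calls == 2    # ...each epsilon was registered exactly once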
