2323from vllm .logger import logger
2424
2525
26+ def _extra_stream_scope_check (match : Match ) -> bool :
27+ """
28+ Checks if all nodes in the same stream.
29+ """
30+ non_default_streams = set ()
31+ has_default = False
32+
33+ for node in match .nodes :
34+ if node .op == "call_function" :
35+ current_stream = node .meta .get ("stream_label" )
36+ if current_stream is None :
37+ has_default = True
38+ else :
39+ non_default_streams .add (current_stream )
40+ if len (non_default_streams ) > 1 :
41+ logger .debug (
42+ f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
43+ f"Multiple streams found: { non_default_streams } . "
44+ f"Fusion is not supported for cross-stream operations."
45+ )
46+ return False
47+
48+ if has_default and len (non_default_streams ) > 0 :
49+ logger .debug (
50+ f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
51+ f"Multiple streams found: { non_default_streams } . "
52+ f"Fusion is not supported for cross-stream operations." )
53+ return False
54+
55+ return True
56+
57+
2658@functools .lru_cache (None )
2759# The replacement registered here will be actually executed after AOT.
2860def replacement_add_rms_norm_quant (epsilon ):
29- if 'torch_npu' not in sys .modules :
30- logger .info (
31- 'The AddRMSNormQuant fusion will only be enabled in a torch npu env.'
32- 'When there is no torch_npu in the env, skip fusion.' )
33- return
34-
35- def _extra_stream_scope_check (match : Match ) -> bool :
36- """
37- Checks if all nodes in the same stream.
38- """
39- non_default_streams = set ()
40- has_default = False
41-
42- for node in match .nodes :
43- if node .op == "call_function" :
44- current_stream = node .meta .get ("stream_label" )
45- if current_stream is None :
46- has_default = True
47- else :
48- non_default_streams .add (current_stream )
49- if len (non_default_streams ) > 1 :
50- logger .debug (
51- f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
52- f"Multiple streams found: { non_default_streams } . "
53- f"Fusion is not supported for cross-stream operations."
54- )
55- return False
56-
57- if has_default and len (non_default_streams ) > 0 :
58- logger .debug (
59- f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
60- f"Multiple streams found: { non_default_streams } . "
61- f"Fusion is not supported for cross-stream operations." )
62- return False
63-
64- return True
6561
6662 def pattern (rms_norm_input : torch .Tensor , residual : torch .Tensor ,
6763 rms_norm_weight : torch .Tensor , scale : torch .Tensor ,
@@ -114,45 +110,8 @@ def get_inputs():
114110 extra_check = _extra_stream_scope_check )
115111
116112
117- @functools .lru_cache (None )
118113# The replacement registered here will be actually executed after AOT.
119114def replacement_add_rms_norm_quant_with_bias (epsilon ):
120- if 'torch_npu' not in sys .modules :
121- logger .info (
122- 'The AddRMSNormQuantWithBias fusion will only be enabled in a torch npu env.'
123- 'When there is no torch_npu in the env, skip fusion.' )
124- return
125-
126- def _extra_stream_scope_check (match : Match ) -> bool :
127- """
128- Checks if all nodes in the same stream.
129- """
130- non_default_streams = set ()
131- has_default = False
132-
133- for node in match .nodes :
134- if node .op == "call_function" :
135- current_stream = node .meta .get ("stream_label" )
136- if current_stream is None :
137- has_default = True
138- else :
139- non_default_streams .add (current_stream )
140- if len (non_default_streams ) > 1 :
141- logger .debug (
142- f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. "
143- f"Multiple streams found: { non_default_streams } . "
144- f"Fusion is not supported for cross-stream operations."
145- )
146- return False
147-
148- if has_default and len (non_default_streams ) > 0 :
149- logger .debug (
150- f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. "
151- f"Multiple streams found: { non_default_streams } . "
152- f"Fusion is not supported for cross-stream operations." )
153- return False
154-
155- return True
156115
157116 def pattern (rms_norm_input : torch .Tensor , residual : torch .Tensor ,
158117 rms_norm_weight : torch .Tensor , scale : torch .Tensor ,
@@ -211,6 +170,126 @@ def get_inputs():
211170 extra_check = _extra_stream_scope_check )
212171
213172
173+ # The replacement registered here will be actually executed after AOT.
174+ def replacement_add_rms_norm_quant_sp_pattern (epsilon ):
175+
176+ def pattern (rms_norm_input : torch .Tensor , residual : torch .Tensor ,
177+ rms_norm_weight : torch .Tensor , scale : torch .Tensor ,
178+ offset : torch .Tensor ):
179+ """
180+ Pattern for AddRMSNormQuantSPPattern fusion.
181+ """
182+ output = torch .ops .npu .npu_add_rms_norm (rms_norm_input , residual ,
183+ rms_norm_weight , epsilon )
184+ out0 = output [0 ]
185+ out1 = output [2 ]
186+ out0 = torch .ops .vllm .maybe_all_gather_and_maybe_unpad (out0 , True )
187+ quantized_output = torch .ops .npu .npu_quantize (out0 , scale , offset ,
188+ torch .qint8 , - 1 , False )
189+ return quantized_output , out1
190+
191+ def replacement (rms_norm_input : torch .Tensor , residual : torch .Tensor ,
192+ rms_norm_weight : torch .Tensor , scale : torch .Tensor ,
193+ offset : torch .Tensor ):
194+ """
195+ Replacement for the AddRMSNormQuantSPPattern fusion.
196+ """
197+ output = torch .ops .npu .npu_add_rms_norm_quant (
198+ rms_norm_input ,
199+ residual ,
200+ rms_norm_weight ,
201+ # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
202+ 1. / scale ,
203+ offset ,
204+ epsilon = epsilon )
205+ quantized_output = output [0 ]
206+ out1 = output [2 ]
207+ quantized_output = torch .ops .vllm .maybe_all_gather_and_maybe_unpad (
208+ quantized_output , True )
209+ return quantized_output , out1
210+
211+ def get_inputs ():
212+ """
213+ Generate example inputs for the AddRMSNormQuantSPPattern fusion pattern.
214+ """
215+ rms_norm_input = torch .randn (2 , 4 , device = "npu" )
216+ residual = torch .randn (2 , 4 , device = "npu" )
217+ rms_norm_weight = torch .randn (4 , device = "npu" )
218+ scale = torch .ones (4 , device = "npu" )
219+ offset = torch .zeros (4 , device = "npu" )
220+ return [rms_norm_input , residual , rms_norm_weight , scale , offset ]
221+
222+ import torchair
223+
224+ torchair .register_replacement (search_fn = pattern ,
225+ replace_fn = replacement ,
226+ example_inputs = get_inputs (),
227+ extra_check = _extra_stream_scope_check )
228+
229+
230+ # The replacement registered here will be actually executed after AOT.
231+ def replacement_add_rms_norm_quant_sp_pattern_with_bias (epsilon ):
232+
233+ def pattern (rms_norm_input : torch .Tensor , residual : torch .Tensor ,
234+ rms_norm_weight : torch .Tensor , scale : torch .Tensor ,
235+ offset : torch .Tensor , bias : torch .Tensor ):
236+ """
237+ Pattern for AddRMSNormQuantSPPatternWithBias fusion.
238+ """
239+ output = torch .ops .npu .npu_add_rms_norm (rms_norm_input , residual ,
240+ rms_norm_weight , epsilon )
241+ out0 = output [0 ]
242+ out1 = output [2 ]
243+ out0 = out0 + bias
244+ out0 = torch .ops .vllm .maybe_all_gather_and_maybe_unpad (out0 , True )
245+ quantized_output = torch .ops .npu .npu_quantize (out0 , scale , offset ,
246+ torch .qint8 , - 1 , False )
247+ return quantized_output , out1
248+
249+ def replacement (rms_norm_input : torch .Tensor , residual : torch .Tensor ,
250+ rms_norm_weight : torch .Tensor , scale : torch .Tensor ,
251+ offset : torch .Tensor , bias : torch .Tensor ):
252+ """
253+ Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
254+ """
255+ output = torch .ops .npu .npu_add_rms_norm_quant (
256+ rms_norm_input ,
257+ residual ,
258+ rms_norm_weight ,
259+ # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
260+ 1. / scale ,
261+ offset ,
262+ epsilon = epsilon ,
263+ beta = bias )
264+ quantized_output = output [0 ]
265+ out1 = output [2 ]
266+ quantized_output = torch .ops .vllm .maybe_all_gather_and_maybe_unpad (
267+ quantized_output , True )
268+ return quantized_output , out1
269+
270+ def get_inputs ():
271+ """
272+ Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern.
273+ """
274+ rms_norm_input = torch .randn (2 , 4 , device = "npu" )
275+ residual = torch .randn (2 , 4 , device = "npu" )
276+ rms_norm_weight = torch .randn (4 , device = "npu" )
277+ rmsnorm_bias = torch .randn (4 , device = "npu" )
278+ scale = torch .ones (4 , device = "npu" )
279+ offset = torch .zeros (4 , device = "npu" )
280+ return [
281+ rms_norm_input , residual , rms_norm_weight , scale , offset ,
282+ rmsnorm_bias
283+ ]
284+
285+ import torchair
286+
287+ torchair .register_replacement (search_fn = pattern ,
288+ replace_fn = replacement ,
289+ example_inputs = get_inputs (),
290+ extra_check = _extra_stream_scope_check )
291+
292+
214293# register converter for pass
215294common_epsilons = [1e-5 , 1e-6 ]
216295for eps in common_epsilons :
@@ -219,3 +298,5 @@ def get_inputs():
219298 )
220299 replacement_add_rms_norm_quant (eps )
221300 replacement_add_rms_norm_quant_with_bias (eps )
301+ replacement_add_rms_norm_quant_sp_pattern (eps )
302+ replacement_add_rms_norm_quant_sp_pattern_with_bias (eps )
0 commit comments