@@ -211,6 +211,200 @@ def get_inputs():
211211 extra_check = _extra_stream_scope_check )
212212
213213
214+ @functools .lru_cache (None )
215+ # The replacement registered here will be actually executed after AOT.
216+ def replacement_add_rms_norm_quant_sp_pattern (epsilon ):
217+ if 'torch_npu' not in sys .modules :
218+ logger .info (
219+ 'The AddRMSNormQuantSPPattern fusion will only be enabled in a torch npu env.'
220+ 'When there is no torch_npu in the env, skip fusion.' )
221+ return
222+
223+ def _extra_stream_scope_check (match : Match ) -> bool :
224+ """
225+ Checks if all nodes in the same stream.
226+ """
227+ non_default_streams = set ()
228+ has_default = False
229+
230+ for node in match .nodes :
231+ if node .op == "call_function" :
232+ current_stream = node .meta .get ("stream_label" )
233+ if current_stream is None :
234+ has_default = True
235+ else :
236+ non_default_streams .add (current_stream )
237+ if len (non_default_streams ) > 1 :
238+ logger .debug (
239+ f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. "
240+ f"Multiple streams found: { non_default_streams } . "
241+ f"Fusion is not supported for cross-stream operations."
242+ )
243+ return False
244+
245+ if has_default and len (non_default_streams ) > 0 :
246+ logger .debug (
247+ f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. "
248+ f"Multiple streams found: { non_default_streams } . "
249+ f"Fusion is not supported for cross-stream operations." )
250+ return False
251+
252+ return True
253+
254+ def pattern (rms_norm_input : torch .Tensor , residual : torch .Tensor ,
255+ rms_norm_weight : torch .Tensor , scale : torch .Tensor ,
256+ offset : torch .Tensor ):
257+ """
258+ Pattern for AddRMSNormQuantSPPattern fusion.
259+ """
260+ output = torch .ops .npu .npu_add_rms_norm (rms_norm_input , residual ,
261+ rms_norm_weight , epsilon )
262+ out0 = output [0 ]
263+ out1 = output [2 ]
264+ out0 = torch .ops .vllm .maybe_all_gather_and_maybe_unpad (out0 , True )
265+ quantized_output = torch .ops .npu .npu_quantize (out0 , scale , offset ,
266+ torch .qint8 , - 1 , False )
267+ return quantized_output , out1
268+
269+ def replacement (rms_norm_input : torch .Tensor , residual : torch .Tensor ,
270+ rms_norm_weight : torch .Tensor , scale : torch .Tensor ,
271+ offset : torch .Tensor ):
272+ """
273+ Replacement for the AddRMSNormQuantSPPattern fusion.
274+ """
275+ output = torch .ops .npu .npu_add_rms_norm_quant (
276+ rms_norm_input ,
277+ residual ,
278+ rms_norm_weight ,
279+ # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
280+ 1. / scale ,
281+ offset ,
282+ epsilon = epsilon )
283+ quantized_output = output [0 ]
284+ out1 = output [2 ]
285+ quantized_output = torch .ops .vllm .maybe_all_gather_and_maybe_unpad (
286+ quantized_output , True )
287+ return quantized_output , out1
288+
289+ def get_inputs ():
290+ """
291+ Generate example inputs for the AddRMSNormQuantSPPattern fusion pattern.
292+ """
293+ rms_norm_input = torch .randn (2 , 4 , device = "npu" )
294+ residual = torch .randn (2 , 4 , device = "npu" )
295+ rms_norm_weight = torch .randn (4 , device = "npu" )
296+ scale = torch .ones (4 , device = "npu" )
297+ offset = torch .zeros (4 , device = "npu" )
298+ return [rms_norm_input , residual , rms_norm_weight , scale , offset ]
299+
300+ import torchair
301+
302+ torchair .register_replacement (search_fn = pattern ,
303+ replace_fn = replacement ,
304+ example_inputs = get_inputs (),
305+ extra_check = _extra_stream_scope_check )
306+
307+
308+ @functools .lru_cache (None )
309+ # The replacement registered here will be actually executed after AOT.
310+ def replacement_add_rms_norm_quant_sp_pattern_with_bias (epsilon ):
311+ if 'torch_npu' not in sys .modules :
312+ logger .info (
313+ 'The AddRMSNormQuantSPPatternWithBias fusion will only be enabled in a torch npu env.'
314+ 'When there is no torch_npu in the env, skip fusion.' )
315+ return
316+
317+ def _extra_stream_scope_check (match : Match ) -> bool :
318+ """
319+ Checks if all nodes in the same stream.
320+ """
321+ non_default_streams = set ()
322+ has_default = False
323+
324+ for node in match .nodes :
325+ if node .op == "call_function" :
326+ current_stream = node .meta .get ("stream_label" )
327+ if current_stream is None :
328+ has_default = True
329+ else :
330+ non_default_streams .add (current_stream )
331+ if len (non_default_streams ) > 1 :
332+ logger .debug (
333+ f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. "
334+ f"Multiple streams found: { non_default_streams } . "
335+ f"Fusion is not supported for cross-stream operations."
336+ )
337+ return False
338+
339+ if has_default and len (non_default_streams ) > 0 :
340+ logger .debug (
341+ f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. "
342+ f"Multiple streams found: { non_default_streams } . "
343+ f"Fusion is not supported for cross-stream operations." )
344+ return False
345+
346+ return True
347+
348+ def pattern (rms_norm_input : torch .Tensor , residual : torch .Tensor ,
349+ rms_norm_weight : torch .Tensor , scale : torch .Tensor ,
350+ offset : torch .Tensor , bias : torch .Tensor ):
351+ """
352+ Pattern for AddRMSNormQuantSPPatternWithBias fusion.
353+ """
354+ output = torch .ops .npu .npu_add_rms_norm (rms_norm_input , residual ,
355+ rms_norm_weight , epsilon )
356+ out0 = output [0 ]
357+ out1 = output [2 ]
358+ out0 = out0 + bias
359+ out0 = torch .ops .vllm .maybe_all_gather_and_maybe_unpad (out0 , True )
360+ quantized_output = torch .ops .npu .npu_quantize (out0 , scale , offset ,
361+ torch .qint8 , - 1 , False )
362+ return quantized_output , out1
363+
364+ def replacement (rms_norm_input : torch .Tensor , residual : torch .Tensor ,
365+ rms_norm_weight : torch .Tensor , scale : torch .Tensor ,
366+ offset : torch .Tensor , bias : torch .Tensor ):
367+ """
368+ Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
369+ """
370+ output = torch .ops .npu .npu_add_rms_norm_quant (
371+ rms_norm_input ,
372+ residual ,
373+ rms_norm_weight ,
374+ # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
375+ 1. / scale ,
376+ offset ,
377+ epsilon = epsilon ,
378+ beta = bias )
379+ quantized_output = output [0 ]
380+ out1 = output [2 ]
381+ quantized_output = torch .ops .vllm .maybe_all_gather_and_maybe_unpad (
382+ quantized_output , True )
383+ return quantized_output , out1
384+
385+ def get_inputs ():
386+ """
387+ Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern.
388+ """
389+ rms_norm_input = torch .randn (2 , 4 , device = "npu" )
390+ residual = torch .randn (2 , 4 , device = "npu" )
391+ rms_norm_weight = torch .randn (4 , device = "npu" )
392+ rmsnorm_bias = torch .randn (4 , device = "npu" )
393+ scale = torch .ones (4 , device = "npu" )
394+ offset = torch .zeros (4 , device = "npu" )
395+ return [
396+ rms_norm_input , residual , rms_norm_weight , scale , offset ,
397+ rmsnorm_bias
398+ ]
399+
400+ import torchair
401+
402+ torchair .register_replacement (search_fn = pattern ,
403+ replace_fn = replacement ,
404+ example_inputs = get_inputs (),
405+ extra_check = _extra_stream_scope_check )
406+
407+
214408# register converter for pass
215409common_epsilons = [1e-5 , 1e-6 ]
216410for eps in common_epsilons :
@@ -219,3 +413,5 @@ def get_inputs():
219413 )
220414 replacement_add_rms_norm_quant (eps )
221415 replacement_add_rms_norm_quant_with_bias (eps )
416+ replacement_add_rms_norm_quant_sp_pattern (eps )
417+ replacement_add_rms_norm_quant_sp_pattern_with_bias (eps )
0 commit comments