@@ -2385,6 +2385,192 @@ def validate(self, model: torch.fx.GraphModule) -> None:
             node_list,
         )
 
+    def test_conv3d_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                weight_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                bias_qspec = QuantizationSpec(
+                    dtype=torch.float32,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                )
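+                # Three specs: uint8 affine activations, int8 weights, and a
+                # float32 bias that is observed via PlaceholderObserver but
+                # never actually quantized.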
+                # conv + bn is fused automatically in PTQ (not configurable),
+                # so annotating conv + relu covers the conv + bn + relu pattern.
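+                # Walk the graph backwards from each relu: annotate only when
+                # its producer is the conv op this pattern targets.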
+                for n in model.graph.nodes:
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.relu.default
+                    ):
+                        continue
+                    relu_node = n
+                    n = n.args[0]
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.conv3d.default
+                    ):
+                        continue
+                    conv_node = n
+                    input_act = conv_node.args[0]
+                    weight = conv_node.args[1]
+                    bias = conv_node.args[2]
+                    conv_node.meta["quantization_annotation"] = (
+                        QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act: act_qspec,
+                                weight: weight_qspec,
+                                bias: bias_qspec,
+                            },
+                            _annotated=True,
+                        )
+                    )
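+                    # The pattern's output is the relu node, so the output
+                    # qspec is attached there rather than on the conv.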
+                    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        output_qspec=act_qspec,
+                        _annotated=True,
+                    )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv3d(2, 2, 3, padding=1)
+                self.bn = torch.nn.BatchNorm3d(2)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.bn(self.conv(x)))
+
+        example_inputs = (torch.randn(1, 2, 2, 5, 5),)
+        node_occurrence = {
+            # one quantize for the conv input and one for the relu output
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            # dequantize for the conv input, the weight, and the relu output
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv3d.default,
+            torch.ops.aten.relu.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        model = M().eval()
+        self._test_quantizer(
+            model,
+            example_inputs,
+            BackendAQuantizer(),
+            node_occurrence,
+            node_list,
+        )
+
+    def test_conv_transpose3d_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                weight_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                bias_qspec = QuantizationSpec(
+                    dtype=torch.float32,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                )
+                # conv_transpose + bn is fused automatically in PTQ (not
+                # configurable), so annotating conv_transpose + relu covers
+                # the conv_transpose + bn + relu pattern.
+                for n in model.graph.nodes:
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.relu.default
+                    ):
+                        continue
+                    relu_node = n
+                    n = n.args[0]
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.conv_transpose3d.input
+                    ):
+                        continue
+                    conv_t_node = n
+                    input_act = conv_t_node.args[0]
+                    weight = conv_t_node.args[1]
+                    bias = conv_t_node.args[2]
+                    conv_t_node.meta["quantization_annotation"] = (
+                        QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act: act_qspec,
+                                weight: weight_qspec,
+                                bias: bias_qspec,
+                            },
+                            _annotated=True,
+                        )
+                    )
+                    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        output_qspec=act_qspec,
+                        _annotated=True,
+                    )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv_t = torch.nn.ConvTranspose3d(2, 2, 3, padding=1)
+                self.bn = torch.nn.BatchNorm3d(2)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.bn(self.conv_t(x)))
+
+        example_inputs = (torch.randn(1, 2, 2, 5, 5),)
+        node_occurrence = {
+            # one quantize for the conv_transpose input and one for the relu output
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            # dequantize for the conv_transpose input, the weight, and the relu output
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv_transpose3d.input,
+            torch.ops.aten.relu.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        model = M().eval()
+        self._test_quantizer(
+            model,
+            example_inputs,
+            BackendAQuantizer(),
+            node_occurrence,
+            node_list,
+        )
+
     def test_multi_users_without_output_observer(self):
         """
         Test the case in which a node is used by multiple users,
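
For context, this is how a backend quantizer like the BackendAQuantizer above is typically driven in the PT2E flow, which is presumably what _test_quantizer does internally. A minimal sketch, assuming the standard prepare_pt2e/convert_pt2e entry points and torch.export capture (the exact capture API varies across PyTorch versions):

    import torch
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    model = M().eval()
    example_inputs = (torch.randn(1, 2, 2, 5, 5),)
    # Capture step (assumption: recent PyTorch; older releases used
    # capture_pre_autograd_graph instead of torch.export).
    exported = torch.export.export(model, example_inputs).module()
    prepared = prepare_pt2e(exported, BackendAQuantizer())  # runs annotate()
    prepared(*example_inputs)  # calibration pass populates the observers
    quantized = convert_pt2e(prepared)  # lowers to quantize/dequantize ops

After convert, the graph contains the quantize_per_tensor/dequantize_per_tensor nodes that node_occurrence and node_list assert on.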