|
19 | 19 | from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns
|
20 | 20 |
|
21 | 21 | from .utils import (
|
22 |
| - _conv1d_bn_example_inputs, |
23 |
| - _conv2d_bn_example_inputs, |
24 | 22 | _get_aten_graph_module_for_pattern,
|
25 | 23 | _is_bn_node,
|
26 | 24 | _is_conv_or_conv_transpose_node,
|
|
35 | 33 | __all__ = [] # type: ignore[var-annotated]
|
36 | 34 |
|
37 | 35 |
|
38 |
| -# Example inputs for quantized and folded conv-bn1d patterns used in convert |
39 |
| -_quantized_conv1d_bn_example_inputs = ( |
40 |
| - torch.randn(1, 1, 3), # x |
41 |
| - torch.randn(1, 1, 1), # conv_weight |
42 |
| - torch.randn(1), # bn_weight |
43 |
| - torch.randn(1), # bn_bias |
44 |
| - torch.randn(1), # bn_running_mean |
45 |
| - torch.randn(1), # bn_running_var |
46 |
| -) |
47 |
| - |
48 |
| -# Example inputs for quantized and folded conv-bn2d patterns used in convert |
49 |
| -_quantized_conv2d_bn_example_inputs = ( |
50 |
| - torch.randn(1, 1, 3, 3), # x |
51 |
| - torch.randn(1, 1, 1, 1), # conv_weight |
52 |
| - torch.randn(1), # bn_weight |
53 |
| - torch.randn(1), # bn_bias |
54 |
| - torch.randn(1), # bn_running_mean |
55 |
| - torch.randn(1), # bn_running_var |
56 |
| -) |
57 |
| - |
58 |
| - |
59 | 36 | def _get_quantized_conv_bn_example_inputs_kwargs(
|
60 | 37 | is_per_channel: bool,
|
61 | 38 | has_bias: bool,
|
@@ -631,6 +608,28 @@ def _get_new_qspec(qspec: QuantizationSpecBase):
|
631 | 608 |
|
632 | 609 |
|
633 | 610 | def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
|
| 611 | + # Example inputs for conv-bn1d patterns |
| 612 | + _conv1d_bn_example_inputs = ( |
| 613 | + torch.randn(1, 1, 3), # x |
| 614 | + torch.randn(1, 1, 1), # conv_weight |
| 615 | + torch.randn(1), # conv_bias |
| 616 | + torch.randn(1), # bn_weight |
| 617 | + torch.randn(1), # bn_bias |
| 618 | + torch.randn(1), # bn_running_mean |
| 619 | + torch.randn(1), # bn_running_var |
| 620 | + ) |
| 621 | + |
| 622 | + # Example inputs for conv-bn2d patterns |
| 623 | + _conv2d_bn_example_inputs = ( |
| 624 | + torch.randn(1, 1, 3, 3), # x |
| 625 | + torch.randn(1, 1, 1, 1), # conv_weight |
| 626 | + torch.randn(1), # conv_bias |
| 627 | + torch.randn(1), # bn_weight |
| 628 | + torch.randn(1), # bn_bias |
| 629 | + torch.randn(1), # bn_running_mean |
| 630 | + torch.randn(1), # bn_running_var |
| 631 | + ) |
| 632 | + |
634 | 633 | has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
|
635 | 634 | if not has_bn:
|
636 | 635 | return m
|
@@ -859,6 +858,26 @@ def _copy_over_q_dq_args(original_node: Node, replacement_node: Node):
|
859 | 858 |
|
860 | 859 |
|
861 | 860 | def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
|
| 861 | + # Example inputs for quantized and folded conv-bn1d patterns used in convert |
| 862 | + _quantized_conv1d_bn_example_inputs = ( |
| 863 | + torch.randn(1, 1, 3), # x |
| 864 | + torch.randn(1, 1, 1), # conv_weight |
| 865 | + torch.randn(1), # bn_weight |
| 866 | + torch.randn(1), # bn_bias |
| 867 | + torch.randn(1), # bn_running_mean |
| 868 | + torch.randn(1), # bn_running_var |
| 869 | + ) |
| 870 | + |
| 871 | + # Example inputs for quantized and folded conv-bn2d patterns used in convert |
| 872 | + _quantized_conv2d_bn_example_inputs = ( |
| 873 | + torch.randn(1, 1, 3, 3), # x |
| 874 | + torch.randn(1, 1, 1, 1), # conv_weight |
| 875 | + torch.randn(1), # bn_weight |
| 876 | + torch.randn(1), # bn_bias |
| 877 | + torch.randn(1), # bn_running_mean |
| 878 | + torch.randn(1), # bn_running_var |
| 879 | + ) |
| 880 | + |
862 | 881 | has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
|
863 | 882 | if not has_bn:
|
864 | 883 | return m
|
|
0 commit comments