Skip to content

Commit c52dcd4

Browse files
authored
[1/n] Move a few utils under pt2e/pt2e/ to pt2e/ (#2083)
* [1/n] Move pt2e/pt2e/utils.py (and graph, export, qat utils) to pt2e/ Summary: #2082 Test Plan: pytest test/quantization/pt2e Reviewers: Subscribers: Tasks: Tags: * fix
1 parent 4805efd commit c52dcd4

17 files changed

+649
-657
lines changed

test/quantization/pt2e/test_graph_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch._dynamo as torchdynamo
1313
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, run_tests
1414

15-
from torchao.quantization.pt2e.pt2e.graph_utils import (
15+
from torchao.quantization.pt2e.graph_utils import (
1616
find_sequential_partitions,
1717
get_equivalent_types,
1818
update_equivalent_types_dict,

test/quantization/pt2e/test_numeric_debugger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
generate_numeric_debug_handle,
2323
prepare_for_propagation_comparison,
2424
)
25-
from torchao.quantization.pt2e.pt2e.graph_utils import bfs_trace_with_node_process
25+
from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
2626
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
2727
from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
2828
XNNPACKQuantizer,

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2077,14 +2077,12 @@ def test_model_is_exported(self):
20772077
exported_gm = export_for_training(m, example_inputs, strict=True).module()
20782078
fx_traced_gm = torch.fx.symbolic_trace(m, example_inputs)
20792079
self.assertTrue(
2080-
torchao.quantization.pt2e.pt2e.export_utils.model_is_exported(exported_gm)
2080+
torchao.quantization.pt2e.export_utils.model_is_exported(exported_gm)
20812081
)
20822082
self.assertFalse(
2083-
torchao.quantization.pt2e.pt2e.export_utils.model_is_exported(fx_traced_gm)
2084-
)
2085-
self.assertFalse(
2086-
torchao.quantization.pt2e.pt2e.export_utils.model_is_exported(m)
2083+
torchao.quantization.pt2e.export_utils.model_is_exported(fx_traced_gm)
20872084
)
2085+
self.assertFalse(torchao.quantization.pt2e.export_utils.model_is_exported(m))
20882086

20892087
def test_reentrant(self):
20902088
"""Test we can safely call quantization apis multiple times"""

torchao/quantization/pt2e/__init__.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@
55
import torch
66
from torch import Tensor
77

8+
from torchao.quantization.pt2e.export_utils import (
9+
_allow_exported_model_train_eval as allow_exported_model_train_eval,
10+
)
11+
from torchao.quantization.pt2e.export_utils import (
12+
_move_exported_model_to_eval as move_exported_model_to_eval,
13+
)
14+
from torchao.quantization.pt2e.export_utils import (
15+
_move_exported_model_to_train as move_exported_model_to_train,
16+
)
817
from torchao.quantization.pt2e.pt2e._numeric_debugger import ( # noqa: F401
918
CUSTOM_KEY,
1019
NUMERIC_DEBUG_HANDLE_KEY,
@@ -13,15 +22,6 @@
1322
generate_numeric_debug_handle,
1423
prepare_for_propagation_comparison,
1524
)
16-
from torchao.quantization.pt2e.pt2e.export_utils import (
17-
_allow_exported_model_train_eval as allow_exported_model_train_eval,
18-
)
19-
from torchao.quantization.pt2e.pt2e.export_utils import (
20-
_move_exported_model_to_eval as move_exported_model_to_eval,
21-
)
22-
from torchao.quantization.pt2e.pt2e.export_utils import (
23-
_move_exported_model_to_train as move_exported_model_to_train,
24-
)
2525

2626
from .fake_quantize import (
2727
FakeQuantize,

torchao/quantization/pt2e/pt2e/_numeric_debugger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from torch.fx import GraphModule, Node
1717
from torch.nn import functional as F
1818

19-
from torchao.quantization.pt2e.pt2e.graph_utils import bfs_trace_with_node_process
19+
from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
2020

2121
NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle"
2222
CUSTOM_KEY = "custom"

torchao/quantization/pt2e/pt2e/duplicate_dq_pass.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
from torch.fx.node import map_arg
1313
from torch.fx.passes.infra.pass_base import PassBase, PassResult
1414

15-
from torchao.quantization.pt2e.pt2e.utils import (
16-
_filter_sym_size_users,
15+
from torchao.quantization.pt2e.quantizer.utils import (
1716
_is_valid_annotation,
1817
)
18+
from torchao.quantization.pt2e.utils import (
19+
_filter_sym_size_users,
20+
)
1921

2022
logger = logging.getLogger(__name__)
2123
logger.setLevel(logging.WARNING)

torchao/quantization/pt2e/pt2e/port_metadata_pass.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
from torch._export.error import InternalError
1313
from torch.fx.passes.infra.pass_base import PassBase, PassResult
1414

15-
from torchao.quantization.pt2e.pt2e.utils import (
15+
from torchao.quantization.pt2e.quantizer import QuantizationSpecBase
16+
from torchao.quantization.pt2e.quantizer.utils import (
17+
_is_valid_annotation,
18+
)
19+
from torchao.quantization.pt2e.utils import (
1620
_filter_sym_size_users,
1721
_find_q_dq_node_for_user,
18-
_is_valid_annotation,
1922
)
20-
from torchao.quantization.pt2e.quantizer import QuantizationSpecBase
2123
from torchao.quantization.quant_primitives import quant_lib # noqa: F401
2224
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
2325

torchao/quantization/pt2e/pt2e/representation/rewrite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from torch.fx import GraphModule
1717
from torch.fx.subgraph_rewriter import replace_pattern
1818

19-
from torchao.quantization.pt2e.pt2e.export_utils import _WrapperModule
20-
from torchao.quantization.pt2e.pt2e.utils import (
19+
from torchao.quantization.pt2e.export_utils import _WrapperModule
20+
from torchao.quantization.pt2e.utils import (
2121
_get_aten_graph_module_for_pattern,
2222
_replace_literals_with_existing_placeholders,
2323
_replace_literals_with_new_placeholders,

0 commit comments

Comments
 (0)