
Commit 792f1c4

albanD authored and pytorchmergebot committed
No actual change, just remove variables containing Tensors from global scope (pytorch#143225)

Pull Request resolved: pytorch#143225
Approved by: https://github.com/ezyang
1 parent afa313e commit 792f1c4

12 files changed: +2156, -2137 lines
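Every hunk below applies the same mechanical pattern: a module-level container holding Tensors becomes a function that builds the container on demand, so merely importing the file no longer allocates Tensors that then live for the whole process. A minimal sketch of the before/after, using hypothetical names rather than the actual helpers:

import torch

# Before: allocated at import time and kept alive for the process
# lifetime, even by consumers that never touch it.
_example_inputs = (
    torch.randn(1, 1, 3),  # x
    torch.randn(1, 1, 1),  # weight
)

# After: nothing is allocated until a caller actually asks.
def get_example_inputs():
    return (
        torch.randn(1, 1, 3),  # x
        torch.randn(1, 1, 1),  # weight
    )

Call sites change from iterating the global (`for t in _example_inputs:`) to calling the getter (`for t in get_example_inputs():`), which is exactly the shape of the test-file edits below.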

test/jit/test_complexity.py (+2, -2)

@@ -20,8 +20,8 @@
 from torch.testing._internal.jit_metaprogramming_utils import (
     get_all_nn_module_tests,
     get_nn_functional_compiled_fn_and_inputs,
+    get_nn_functional_tests,
     get_nn_mod_test_name,
-    nn_functional_tests,
     try_get_nn_module_compiled_mod_and_inputs,
 )
 from torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase
@@ -70,7 +70,7 @@ def tearDown(self):
     def test_generated_functional_tests(self):
         with enable_profiling_mode():
             stats = [("Name", "Ifs/Loops", "non-tensor ops")]
-            for test in nn_functional_tests:
+            for test in get_nn_functional_tests():
                 test_name = test[0]

                 fn, inputs = get_nn_functional_compiled_fn_and_inputs(*test)

test/test_cpp_api_parity.py (+1, -1)

@@ -42,7 +42,7 @@ class TestCppApiParity(common.TestCase):
     (sample_module.module_tests, common_nn.NewModuleTest),
     (sample_functional.functional_tests, common_nn.NewModuleTest),
     (common_nn.module_tests, common_nn.NewModuleTest),
-    (common_nn.new_module_tests, common_nn.NewModuleTest),
+    (common_nn.get_new_module_tests(), common_nn.NewModuleTest),
     (common_nn.criterion_tests, common_nn.CriterionTest),
 ]:
     for test_params_dict in test_params_dicts:

test/test_expanded_weights.py (+6, -2)

@@ -25,7 +25,11 @@
 )
 from torch.testing._internal.common_methods_invocations import op_db, SampleInput
 from torch.testing._internal.common_modules import module_db, modules
-from torch.testing._internal.common_nn import module_tests, new_module_tests, TestBase
+from torch.testing._internal.common_nn import (
+    get_new_module_tests,
+    module_tests,
+    TestBase,
+)
 from torch.testing._internal.common_utils import (
     freeze_rng_state,
     make_tensor,
@@ -1011,7 +1015,7 @@ def filter_supported_tests(t):
         # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
         # These currently use the legacy nn tests
         supported_tests = [
-            t for t in module_tests + new_module_tests if filter_supported_tests(t)
+            t for t in module_tests + get_new_module_tests() if filter_supported_tests(t)
         ]
         for test_param in supported_tests:
             if "constructor" not in test_param:

test/test_fx_experimental.py (+2, -2)

@@ -50,7 +50,7 @@
     ops,
 )
 from torch.testing._internal.common_methods_invocations import op_db
-from torch.testing._internal.common_nn import module_tests, new_module_tests
+from torch.testing._internal.common_nn import module_tests, get_new_module_tests
 from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
 from torch.testing._internal.jit_utils import JitTestCase
 import torch.utils._pytree as pytree
@@ -1006,7 +1006,7 @@ def test_normalize_modules_exhaustive(self):
         Exhaustively test `Node.normalized_arguments` on all standard
         torch.nn Module classes
         """
-        for test_params in module_tests + new_module_tests:
+        for test_params in module_tests + get_new_module_tests():
             if "constructor" not in test_params:
                 constructor = getattr(torch.nn, test_params["module_name"])
             else:

test/test_jit.py (+3, -3)

@@ -107,10 +107,10 @@
 from torch.testing._internal.jit_metaprogramming_utils import (
     get_script_args,
     create_input, unpack_variables,
-    additional_module_tests, EXCLUDE_SCRIPT_MODULES,
+    get_all_nn_module_tests, EXCLUDE_SCRIPT_MODULES,
     get_nn_module_name_from_kwargs, get_nn_mod_test_name, script_method_template)

-from torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests
+from torch.testing._internal.common_nn import criterion_tests

 # For testing truediv in python 2
 from torch.testing._internal.test_module.future_div import div_int_future, div_float_future
@@ -16247,7 +16247,7 @@ def test_version(self):
         # issue gh-32561
         self.assertTrue(torch.__version__.startswith(torch.onnx.producer_version))

-for test in module_tests + new_module_tests + additional_module_tests:
+for test in get_all_nn_module_tests():
     add_nn_module_test(**test)

 for test in criterion_tests:
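test_jit.py also swaps three separately imported collections for a single `get_all_nn_module_tests()` entry point from jit_metaprogramming_utils. The real helper is not shown in this diff; a self-contained sketch of the aggregation pattern it presumably implements (all names and bodies here are illustrative, not the actual code):

# Illustrative stand-ins for the real per-source getters.
def get_module_tests():
    return [{"module_name": "Linear", "constructor_args": (4, 2)}]

def get_additional_module_tests():
    return [{"module_name": "ReLU"}]

def get_all_nn_module_tests():
    # One entry point concatenating every source, so callers stop
    # importing three module-level globals and summing them inline.
    return get_module_tests() + get_additional_module_tests()

for test in get_all_nn_module_tests():
    print(test["module_name"])  # Linear, ReLU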

test/test_nn.py (+2, -2)

@@ -38,7 +38,7 @@
 from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
     module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
-    ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
+    ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
 from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \
     dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
     skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \
@@ -7332,7 +7332,7 @@ def with_tf32_on(self, test=test, kwargs=kwargs):
     else:
         add(cuda_test_name, with_tf32_off)

-for test_params in module_tests + new_module_tests:
+for test_params in module_tests + get_new_module_tests():
     # TODO: CUDA is not implemented yet
     if 'constructor' not in test_params:
         name = test_params.pop('module_name')

torch/ao/quantization/pt2e/qat_utils.py (+42, -23)

@@ -19,8 +19,6 @@
 from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns

 from .utils import (
-    _conv1d_bn_example_inputs,
-    _conv2d_bn_example_inputs,
     _get_aten_graph_module_for_pattern,
     _is_bn_node,
     _is_conv_or_conv_transpose_node,
@@ -35,27 +33,6 @@
 __all__ = []  # type: ignore[var-annotated]


-# Example inputs for quantized and folded conv-bn1d patterns used in convert
-_quantized_conv1d_bn_example_inputs = (
-    torch.randn(1, 1, 3),  # x
-    torch.randn(1, 1, 1),  # conv_weight
-    torch.randn(1),  # bn_weight
-    torch.randn(1),  # bn_bias
-    torch.randn(1),  # bn_running_mean
-    torch.randn(1),  # bn_running_var
-)
-
-# Example inputs for quantized and folded conv-bn2d patterns used in convert
-_quantized_conv2d_bn_example_inputs = (
-    torch.randn(1, 1, 3, 3),  # x
-    torch.randn(1, 1, 1, 1),  # conv_weight
-    torch.randn(1),  # bn_weight
-    torch.randn(1),  # bn_bias
-    torch.randn(1),  # bn_running_mean
-    torch.randn(1),  # bn_running_var
-)
-
-
 def _get_quantized_conv_bn_example_inputs_kwargs(
     is_per_channel: bool,
     has_bias: bool,
@@ -631,6 +608,28 @@ def _get_new_qspec(qspec: QuantizationSpecBase):


 def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
+    # Example inputs for conv-bn1d patterns
+    _conv1d_bn_example_inputs = (
+        torch.randn(1, 1, 3),  # x
+        torch.randn(1, 1, 1),  # conv_weight
+        torch.randn(1),  # conv_bias
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    # Example inputs for conv-bn2d patterns
+    _conv2d_bn_example_inputs = (
+        torch.randn(1, 1, 3, 3),  # x
+        torch.randn(1, 1, 1, 1),  # conv_weight
+        torch.randn(1),  # conv_bias
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
     has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
     if not has_bn:
         return m
@@ -859,6 +858,26 @@ def _copy_over_q_dq_args(original_node: Node, replacement_node: Node):


 def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
+    # Example inputs for quantized and folded conv-bn1d patterns used in convert
+    _quantized_conv1d_bn_example_inputs = (
+        torch.randn(1, 1, 3),  # x
+        torch.randn(1, 1, 1),  # conv_weight
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    # Example inputs for quantized and folded conv-bn2d patterns used in convert
+    _quantized_conv2d_bn_example_inputs = (
+        torch.randn(1, 1, 3, 3),  # x
+        torch.randn(1, 1, 1, 1),  # conv_weight
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
     has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
     if not has_bn:
         return m
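Moving these tuples into `_fuse_conv_bn_qat` and `_fold_conv_bn_qat` means the `torch.randn` calls now run on every invocation rather than once at import. If that cost ever mattered, a cached getter would keep imports allocation-free while building the Tensors only once; a sketch of that alternative (not what this PR does):

import functools

import torch

@functools.lru_cache(maxsize=1)
def _cached_conv1d_bn_example_inputs():
    # Built on first call, then reused. Trade-off: once called, the
    # Tensors are again pinned for the life of the process, which is
    # exactly what this PR avoids for the common import-only case.
    return (
        torch.randn(1, 1, 3),  # x
        torch.randn(1, 1, 1),  # conv_weight
    )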
