Skip to content

Commit 59171d8

Browse files
committed
up
1 parent b0167de commit 59171d8

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import copy
8+
import itertools
89
import tempfile
910
import unittest
10-
import itertools
11+
1112
import torch
13+
from torch.testing import FileCheck
1214

1315
from torchao.dtypes import PlainLayout
1416
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
@@ -24,7 +26,6 @@
2426
)
2527
from torchao.quantization.quant_api import quantize_
2628
from torchao.utils import unwrap_tensor_subclass
27-
from torch.testing import FileCheck
2829

2930

3031
class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
@@ -39,11 +40,25 @@ def test_accuracy(self):
3940
model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
4041

4142
reference_layout = PlainLayout()
42-
test_layouts = [PackedLinearInt8DynamicActivationIntxWeightLayout(), QDQLayout()]
43-
test_weight_dtypes = [torch.int1, torch.int2, torch.int3, torch.int4, torch.int5, torch.int6, torch.int7, torch.int8]
43+
test_layouts = [
44+
PackedLinearInt8DynamicActivationIntxWeightLayout(),
45+
QDQLayout(),
46+
]
47+
test_weight_dtypes = [
48+
torch.int1,
49+
torch.int2,
50+
torch.int3,
51+
torch.int4,
52+
torch.int5,
53+
torch.int6,
54+
torch.int7,
55+
torch.int8,
56+
]
4457
test_has_weight_zeros = [True, False]
4558
test_granularities = [PerGroup(128), PerRow()]
46-
for layout, weight_dtype, has_weight_zeros, granularity in itertools.product(test_layouts, test_weight_dtypes, test_has_weight_zeros, test_granularities):
59+
for layout, weight_dtype, has_weight_zeros, granularity in itertools.product(
60+
test_layouts, test_weight_dtypes, test_has_weight_zeros, test_granularities
61+
):
4762
quantized_model = copy.deepcopy(model)
4863
quantize_(
4964
quantized_model,
@@ -71,7 +86,9 @@ def test_accuracy(self):
7186
expected_result = quantized_model_reference(activations)
7287
self.assertTrue(torch.allclose(result, expected_result, atol=1e-6))
7388

74-
def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout(self):
89+
def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout(
90+
self,
91+
):
7592
"""
7693
Checks that models quantized with PackedLinearInt8DynamicActivationIntxWeightLayout() work with
7794
torch.export.export, torch.compile, and AOTI.
@@ -126,7 +143,7 @@ def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout(s
126143
fn = torch._inductor.aoti_load_package(package_path)
127144
aoti_results = fn(activations)
128145
self.assertTrue(torch.allclose(eager_results, aoti_results))
129-
146+
130147
def test_export_QDQLayout(self):
131148
"""
132149
Checks that models quantized with TestQDQLayout() export as expected
@@ -167,6 +184,3 @@ def test_export_QDQLayout(self):
167184
FileCheck().check_count(line, 1, exactly=True).run(
168185
exported.graph_module.code
169186
)
170-
171-
if __name__ == "__main__":
172-
unittest.main()

0 commit comments

Comments
 (0)