5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
import copy
8
+ import itertools
8
9
import tempfile
9
10
import unittest
10
- import itertools
11
+
11
12
import torch
13
+ from torch .testing import FileCheck
12
14
13
15
from torchao .dtypes import PlainLayout
14
16
from torchao .experimental .packed_linear_int8_dynamic_activation_intx_weight_layout import (
24
26
)
25
27
from torchao .quantization .quant_api import quantize_
26
28
from torchao .utils import unwrap_tensor_subclass
27
- from torch .testing import FileCheck
28
29
29
30
30
31
class TestInt8DynamicActivationIntxWeight (unittest .TestCase ):
@@ -39,11 +40,25 @@ def test_accuracy(self):
39
40
model = torch .nn .Sequential (* [torch .nn .Linear (k , n , bias = False )])
40
41
41
42
reference_layout = PlainLayout ()
42
- test_layouts = [PackedLinearInt8DynamicActivationIntxWeightLayout (), QDQLayout ()]
43
- test_weight_dtypes = [torch .int1 , torch .int2 , torch .int3 , torch .int4 , torch .int5 , torch .int6 , torch .int7 , torch .int8 ]
43
+ test_layouts = [
44
+ PackedLinearInt8DynamicActivationIntxWeightLayout (),
45
+ QDQLayout (),
46
+ ]
47
+ test_weight_dtypes = [
48
+ torch .int1 ,
49
+ torch .int2 ,
50
+ torch .int3 ,
51
+ torch .int4 ,
52
+ torch .int5 ,
53
+ torch .int6 ,
54
+ torch .int7 ,
55
+ torch .int8 ,
56
+ ]
44
57
test_has_weight_zeros = [True , False ]
45
58
test_granularities = [PerGroup (128 ), PerRow ()]
46
- for layout , weight_dtype , has_weight_zeros , granularity in itertools .product (test_layouts , test_weight_dtypes , test_has_weight_zeros , test_granularities ):
59
+ for layout , weight_dtype , has_weight_zeros , granularity in itertools .product (
60
+ test_layouts , test_weight_dtypes , test_has_weight_zeros , test_granularities
61
+ ):
47
62
quantized_model = copy .deepcopy (model )
48
63
quantize_ (
49
64
quantized_model ,
@@ -71,7 +86,9 @@ def test_accuracy(self):
71
86
expected_result = quantized_model_reference (activations )
72
87
self .assertTrue (torch .allclose (result , expected_result , atol = 1e-6 ))
73
88
74
- def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout (self ):
89
+ def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout (
90
+ self ,
91
+ ):
75
92
"""
76
93
Checks that models quantized with PackedLinearInt8DynamicActivationIntxWeightLayout() work with
77
94
torch.export.export, torch.compile, and AOTI.
@@ -126,7 +143,7 @@ def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout(s
126
143
fn = torch ._inductor .aoti_load_package (package_path )
127
144
aoti_results = fn (activations )
128
145
self .assertTrue (torch .allclose (eager_results , aoti_results ))
129
-
146
+
130
147
def test_export_QDQLayout (self ):
131
148
"""
132
149
Checks that models quantized with QDQLayout() export as expected
@@ -167,6 +184,3 @@ def test_export_QDQLayout(self):
167
184
FileCheck ().check_count (line , 1 , exactly = True ).run (
168
185
exported .graph_module .code
169
186
)
170
-
171
- if __name__ == "__main__" :
172
- unittest .main ()
0 commit comments