Skip to content

Commit 6fe41c2

Browse files
authored
config migration: int* (#1696)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent 2e51872 commit 6fe41c2

File tree

4 files changed

+173
-119
lines changed

4 files changed

+173
-119
lines changed

test/dtypes/test_affine_quantized.py

+1
Original file line number | Diff line number | Diff line change
@@ -218,6 +218,7 @@ def test_flatten_unflatten(self, device, dtype):
218218
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
219219
if isinstance(apply_quant, AOBaseConfig):
220220
quantize_(linear, apply_quant)
221+
ql = linear
221222
else:
222223
# TODO(#1690): delete this once config migration is done
223224
ql = apply_quant(linear)

test/quantization/test_quant_api.py

+12-1
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
float8_dynamic_activation_float8_weight,
3434
float8_static_activation_float8_weight,
3535
float8_weight_only,
36+
int4_dynamic_activation_int4_weight,
3637
int4_weight_only,
3738
int8_dynamic_activation_int4_weight,
3839
int8_dynamic_activation_int8_weight,
@@ -50,6 +51,7 @@
5051
TORCH_VERSION_AT_LEAST_2_5,
5152
TORCH_VERSION_AT_LEAST_2_6,
5253
is_sm_at_least_89,
54+
is_sm_at_least_90,
5355
unwrap_tensor_subclass,
5456
)
5557

@@ -798,6 +800,10 @@ def test_int4wo_cpu(self, dtype, x_dim):
798800
float8_weight_only(),
799801
float8_dynamic_activation_float8_weight(),
800802
float8_static_activation_float8_weight(scale=torch.tensor([1.0])),
803+
int4_dynamic_activation_int4_weight(),
804+
int8_dynamic_activation_int8_weight(),
805+
int8_dynamic_activation_int4_weight(),
806+
int8_weight_only(),
801807
],
802808
)
803809
def test_workflow_e2e_numerics(self, config):
@@ -816,6 +822,11 @@ def test_workflow_e2e_numerics(self, config):
816822
and not is_sm_at_least_89()
817823
):
818824
return unittest.skip("requires CUDA capability 8.9 or greater")
825+
elif (
826+
isinstance(config, int4_dynamic_activation_int4_weight)
827+
and is_sm_at_least_90()
828+
):
829+
return unittest.skip("only supported on CUDA capability 8.9, not greater")
819830

820831
# scale has to be moved to cuda here because the parametrization init
821832
# code happens before gating for cuda availability
@@ -837,7 +848,7 @@ def test_workflow_e2e_numerics(self, config):
837848
y_q = m_q(x)
838849

839850
sqnr = compute_error(y_ref, y_q)
840-
assert sqnr >= 20, f"SQNR {sqnr} is too low"
851+
assert sqnr >= 16.5, f"SQNR {sqnr} is too low"
841852

842853

843854
class TestMultiTensorFlow(TestCase):

torchao/quantization/__init__.py

+8
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,11 @@
4949
Float8DynamicActivationFloat8WeightConfig,
5050
Float8StaticActivationFloat8WeightConfig,
5151
Float8WeightOnlyConfig,
52+
Int4DynamicActivationInt4WeightConfig,
5253
Int4WeightOnlyConfig,
54+
Int8DynamicActivationInt4WeightConfig,
55+
Int8DynamicActivationInt8WeightConfig,
56+
Int8WeightOnlyConfig,
5357
float8_dynamic_activation_float8_weight,
5458
float8_static_activation_float8_weight,
5559
float8_weight_only,
@@ -123,7 +127,11 @@
123127
"fpx_weight_only",
124128
"gemlite_uintx_weight_only",
125129
"swap_conv2d_1x1_to_linear",
130+
"Int4DynamicActivationInt4WeightConfig",
131+
"Int8DynamicActivationInt4WeightConfig",
132+
"Int8DynamicActivationInt8WeightConfig",
126133
"Int4WeightOnlyConfig",
134+
"Int8WeightOnlyConfig",
127135
"Float8WeightOnlyConfig",
128136
"Float8DynamicActivationFloat8WeightConfig",
129137
"Float8StaticActivationFloat8WeightConfig",

0 commit comments

Comments (0)