
Commit c737354

config migration: float*

Summary:

TODO write me

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 022867b
ghstack-comment-id: 2649492752
Pull Request resolved: #1694

1 parent 1ed0394  commit c737354
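
This commit continues the config migration tracked in #1690: the float8 workflows (float8_weight_only, float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight) become AOBaseConfig subclasses rather than bare functions that rewrite a module, as the isinstance checks in the test changes below rely on. As a rough illustration of the post-migration calling convention (not part of the patch; default constructor arguments for Float8WeightOnlyConfig are assumed):

    import torch
    from torchao.quantization import Float8WeightOnlyConfig, quantize_

    # quantize_ consumes a config object describing the workflow,
    # instead of a callable that transforms the module.
    model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
    quantize_(model, Float8WeightOnlyConfig())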

File tree: 4 files changed, +198 −99 lines changed

test/dtypes/test_affine_quantized.py (+11 −3)

@@ -123,16 +123,24 @@ def test_weights_only(self, apply_quant):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
     def test_to_device(self, apply_quant):
+        def _apply(module, config_or_subclass_inserter):
+            if isinstance(config_or_subclass_inserter, AOBaseConfig):
+                quantize_(module, config_or_subclass_inserter)
+            else:
+                # TODO(#1690): delete this once config migration is done
+                module = config_or_subclass_inserter(module)
+            return module
+
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.to("cuda")
 
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.to(device="cuda")
 
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.cuda()
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
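
The _apply shim exists because get_quantization_functions still yields a mix of new AOBaseConfig instances and legacy "subclass inserter" callables while the migration is in flight: configs are applied in place by quantize_, while legacy inserters return the transformed module. A hedged sketch of the two call shapes the shim bridges; legacy_inserter below is a stand-in for the pre-migration style, not a real torchao API:

    import torch
    from torchao.quantization import Int4WeightOnlyConfig, quantize_

    # New style: a config object, applied in place by quantize_.
    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
    quantize_(linear, Int4WeightOnlyConfig())

    # Old style: a callable that returns the transformed module
    # (the branch the TODO(#1690) comment will delete).
    def legacy_inserter(module: torch.nn.Module) -> torch.nn.Module:
        return module  # stand-in; real inserters swap in quantized weights

    linear = legacy_inserter(torch.nn.Linear(128, 256, dtype=torch.bfloat16))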

test/quantization/test_quant_api.py (+36 −5)

@@ -30,6 +30,9 @@
     Quantizer,
     TwoStepQuantizer,
     _replace_with_custom_fn_if_matches_filter,
+    float8_dynamic_activation_float8_weight,
+    float8_static_activation_float8_weight,
+    float8_weight_only,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
@@ -46,6 +49,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    is_sm_at_least_89,
     unwrap_tensor_subclass,
 )
 
@@ -784,28 +788,55 @@ def test_int4wo_cpu(self, dtype, x_dim):
         assert "_weight_int4pack_mm_for_cpu" in code[0]
         assert "aten.mm.default" not in code[0]
 
+    # TODO(#1690): move to new config names
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_int4_weight_only_numerics(self):
+    @common_utils.parametrize(
+        "config",
+        [
+            int4_weight_only(),
+            float8_weight_only(),
+            float8_dynamic_activation_float8_weight(),
+            float8_static_activation_float8_weight(scale=torch.tensor([1.0])),
+        ],
+    )
+    def test_workflow_e2e_numerics(self, config):
         """
         Simple test of e2e int4_weight_only workflow, comparing numerics
         to a bfloat16 baseline.
         """
+        if (
+            isinstance(
+                config,
+                (
+                    float8_dynamic_activation_float8_weight,
+                    float8_static_activation_float8_weight,
+                ),
+            )
+            and not is_sm_at_least_89()
+        ):
+            return unittest.skip("requires CUDA capability 8.9 or greater")
+
+        # scale has to be moved to cuda here because the parametrization init
+        # code happens before gating for cuda availability
+        if isinstance(config, float8_static_activation_float8_weight):
+            config.scale = config.scale.to("cuda")
+
         # set up inputs
         x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
         m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
-        m_int4_wo = copy.deepcopy(m_ref)
+        m_q = copy.deepcopy(m_ref)
 
         # quantize
-        quantize_(m_int4_wo, int4_weight_only())
+        quantize_(m_q, config)
 
         with torch.no_grad():
             y_ref = m_ref(x)
-            y_int4_wo = m_int4_wo(x)
+            y_q = m_q(x)
 
-        sqnr = compute_error(y_ref, y_int4_wo)
+        sqnr = compute_error(y_ref, y_q)
         assert sqnr >= 20, f"SQNR {sqnr} is too low"
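
The generalized test compares each quantized model against a bfloat16 baseline via compute_error, an SQNR measure in decibels; the float8 variants are additionally gated on is_sm_at_least_89 because FP8 kernels require CUDA compute capability 8.9 or greater. A sketch of an SQNR computation of this general shape (torchao's actual compute_error may differ in details):

    import torch

    def sqnr_db(y_ref: torch.Tensor, y_q: torch.Tensor) -> torch.Tensor:
        # Signal-to-quantization-noise ratio in dB: the test above
        # requires at least 20 dB between reference and quantized outputs.
        signal = torch.linalg.norm(y_ref)
        noise = torch.linalg.norm(y_ref - y_q)
        return 20 * torch.log10(signal / noise)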

torchao/quantization/__init__.py (+6 −0)

@@ -46,6 +46,9 @@
     AffineQuantizedObserverBase,
 )
 from .quant_api import (
+    Float8DynamicActivationFloat8WeightConfig,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     Int4WeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
@@ -121,6 +124,9 @@
     "gemlite_uintx_weight_only",
     "swap_conv2d_1x1_to_linear",
     "Int4WeightOnlyConfig",
+    "Float8WeightOnlyConfig",
+    "Float8DynamicActivationFloat8WeightConfig",
+    "Float8StaticActivationFloat8WeightConfig",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",
