
Commit 6ccc4cb

config migration: float*

Summary: TODO write me
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: f096aa90b0615af5a09f22a5964fd74f4f28f57c
ghstack-comment-id: 2649492752
Pull Request resolved: #1694

1 parent 1c9e446 commit 6ccc4cb
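
For context: this commit continues the #1690 config migration for the float8 workflows. After it, the float8 entry points behave as AOBaseConfig subclasses rather than bare module-transforming callables, which is exactly what the new isinstance checks in the tests below rely on. A minimal caller-side sketch (import paths assumed from torchao's public API, not shown in this diff):

import torch
from torchao.quantization import float8_weight_only, quantize_

# a small bfloat16 model on GPU, matching the shapes used in the tests below
m = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()

# new style: float8_weight_only() constructs a config object, and
# quantize_ applies it to the model's weights in place
quantize_(m, float8_weight_only())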

File tree

3 files changed: +189 −99 lines changed

test/dtypes/test_affine_quantized.py (+11 −3)

@@ -123,16 +123,24 @@ def test_weights_only(self, apply_quant):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
     def test_to_device(self, apply_quant):
+        def _apply(module, config_or_subclass_inserter):
+            if isinstance(config_or_subclass_inserter, AOBaseConfig):
+                quantize_(module, config_or_subclass_inserter)
+            else:
+                # TODO(#1690): delete this once config migration is done
+                module = config_or_subclass_inserter(module)
+            return module
+
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.to("cuda")

         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.to(device="cuda")

         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.cuda()

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
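
The _apply shim above lets one parametrized test exercise both calling conventions during the migration window. The same dispatch as a standalone helper, a sketch only (the AOBaseConfig import path is assumed from torchao's layout):

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_

def apply_quant_compat(module, config_or_subclass_inserter):
    # new style: an AOBaseConfig instance goes straight to quantize_,
    # which swaps the module's weights in place
    if isinstance(config_or_subclass_inserter, AOBaseConfig):
        quantize_(module, config_or_subclass_inserter)
    else:
        # legacy style, kept until the #1690 migration is complete:
        # the callable itself transforms and returns the module
        module = config_or_subclass_inserter(module)
    return module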

test/quantization/test_quant_api.py (+33 −5)

@@ -30,6 +30,9 @@
     Quantizer,
     TwoStepQuantizer,
     _replace_with_custom_fn_if_matches_filter,
+    float8_dynamic_activation_float8_weight,
+    float8_static_activation_float8_weight,
+    float8_weight_only,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
@@ -46,6 +49,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    is_sm_at_least_89,
     unwrap_tensor_subclass,
 )

@@ -784,28 +788,52 @@ def test_int4wo_cpu(self, dtype, x_dim):
         assert "_weight_int4pack_mm_for_cpu" in code[0]
         assert "aten.mm.default" not in code[0]

+    # TODO(#1690): move to new config names
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_int4_weight_only_numerics(self):
+    @common_utils.parametrize(
+        "config",
+        [
+            int4_weight_only(),
+            float8_weight_only(),
+            float8_dynamic_activation_float8_weight(),
+            float8_static_activation_float8_weight(
+                scale=torch.tensor([1.0], device="cuda")
+            ),
+        ],
+    )
+    def test_workflow_e2e_numerics(self, config):
         """
         Simple test of e2e int4_weight_only workflow, comparing numerics
         to a bfloat16 baseline.
         """
+        if (
+            isinstance(
+                config,
+                (
+                    float8_dynamic_activation_float8_weight,
+                    float8_static_activation_float8_weight,
+                ),
+            )
+            and not is_sm_at_least_89()
+        ):
+            return unittest.skip("requires CUDA capability 8.9 or greater")
+
         # set up inputs
         x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
         m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
-        m_int4_wo = copy.deepcopy(m_ref)
+        m_q = copy.deepcopy(m_ref)

         # quantize
-        quantize_(m_int4_wo, int4_weight_only())
+        quantize_(m_q, config)

         with torch.no_grad():
             y_ref = m_ref(x)
-            y_int4_wo = m_int4_wo(x)
+            y_q = m_q(x)

-        sqnr = compute_error(y_ref, y_int4_wo)
+        sqnr = compute_error(y_ref, y_q)
         assert sqnr >= 20, f"SQNR {sqnr} is too low"
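
Two details of the new test are worth spelling out: the pass/fail criterion is the SQNR between the quantized and bfloat16 outputs, and the float8 variants are skipped below CUDA compute capability 8.9 because that is the first hardware generation with float8 matmul support. A self-contained sketch of both checks, assuming the conventional SQNR definition (compute_error itself is imported from torchao's utilities and may differ in detail):

import torch

def compute_sqnr(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 20 * log10(||ref|| / ||ref - test||); higher is better.
    # The test above requires at least 20 dB against the bfloat16 baseline.
    return 20 * torch.log10(
        torch.linalg.norm(ref) / torch.linalg.norm(ref - test)
    )

def sm_at_least_89() -> bool:
    # mirrors the is_sm_at_least_89 gate: float8 kernels need CUDA
    # compute capability (8, 9) (Ada) or newer
    return (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (8, 9)
    )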
