
Commit 8940aa7

[float8] Prevent quantize_affine_float8/dequantize_affine_float8 from being decomposed by inductor (#2379)

* quantize_affine_float8/dequantize_affine_float8 are no longer decomposed by inductor
* remove a redundant unittest.skipIf
* fix a rebase issue
* switch the dispatch-key choice to a decomposed flag
* for clarity, rename the flag to inductor_decomposed
* update the unit test path
1 parent 4ebc9c0 commit 8940aa7

File tree

3 files changed: +70 -3 lines

* test/dtypes/test_affine_quantized_float.py
* torchao/quantization/quant_primitives.py
* torchao/utils.py

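In short, the commit stops inductor from inlining the two float8 primitives: their eager implementations move from the CompositeImplicitAutograd to the CompositeExplicitAutograd dispatch key, no inductor decomposition is registered for them, and Meta kernels are added so torch.compile can still infer output shapes and dtypes. A minimal, self-contained sketch of that recipe follows; the demo namespace and scale_to_fp8 op are invented for illustration and are not part of torchao:

import torch

# Hypothetical namespace and op, for illustration only.
_lib = torch.library.Library("demo", "FRAGMENT")
_lib.define("scale_to_fp8(Tensor t, Tensor scale, ScalarType dtype) -> Tensor")


def _scale_to_fp8(t, scale, dtype):
    # Eager kernel. Registered under CompositeExplicitAutograd (not
    # CompositeImplicitAutograd) and with no inductor decomposition, so the op
    # shows up as a single call in inductor's generated code.
    return (t / scale).to(dtype)


_lib.impl("scale_to_fp8", _scale_to_fp8, "CompositeExplicitAutograd")


def _scale_to_fp8_meta(t, scale, dtype):
    # Meta kernel: once the op is opaque, torch.compile/torch.export need this
    # to propagate output shape and dtype through fake tensors.
    return torch.empty_like(t, dtype=dtype)


_lib.impl("scale_to_fp8", _scale_to_fp8_meta, "Meta")

compiled = torch.compile(torch.ops.demo.scale_to_fp8)
out = compiled(torch.randn(4, 4), torch.tensor(2.0), torch.float8_e4m3fn)
print(out.dtype)  # torch.float8_e4m3fn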

test/dtypes/test_affine_quantized_float.py

Lines changed: 40 additions & 0 deletions

@@ -675,6 +675,46 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)
 
+    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
+    def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
+        quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
+        dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
+        input = torch.randn(10, 10)
+        with torch.no_grad():
+            torch._dynamo.reset()
+            expected_scale = torch.tensor(2.0)
+            expected_quantized = quantize_affine_float8(
+                input,
+                expected_scale,
+                float8_dtype=float8_dtype,
+            )
+            expected_dequantized = dequantize_affine_float8(
+                expected_quantized,
+                expected_scale,
+                output_dtype=hp_dtype,
+            )
+            test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
+                torch.compile(quantize_affine_float8),
+                input,
+                expected_scale,
+                float8_dtype=float8_dtype,
+            )
+            torch.testing.FileCheck().check(
+                "torch.ops.torchao.quantize_affine_float8.default"
+            ).run(code_q)
+            test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
+                torch.compile(dequantize_affine_float8),
+                test_q,
+                expected_scale,
+                hp_dtype,
+            )
+            torch.testing.FileCheck().check(
+                "torch.ops.torchao.dequantize_affine_float8.default"
+            ).run(code_dq)
+            torch.testing.assert_close(expected_quantized, test_q)
+            torch.testing.assert_close(expected_dequantized, test_dq)
+
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
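The new test leans on two PyTorch testing helpers: torch._inductor.utils.run_and_get_code returns the compiled result together with the Python source that inductor generated, and torch.testing.FileCheck asserts that a pattern occurs in that source. A tiny sketch of the checking primitive, run here on a hand-written string rather than real inductor output:

import torch

# FileCheck scans the string for each .check() pattern in order and raises if one is missing.
fake_generated_code = "buf0 = torch.ops.torchao.quantize_affine_float8.default(arg0_1, arg1_1)"
torch.testing.FileCheck().check(
    "torch.ops.torchao.quantize_affine_float8.default"
).run(fake_generated_code)  # passes only because the op call appears verbatim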

torchao/quantization/quant_primitives.py

Lines changed: 20 additions & 0 deletions

@@ -2270,6 +2270,7 @@ def _expand_scale_to_tensor_shape(
     return expanded_scale
 
 
+@_register_custom_op(quant_lib, False)
 def _quantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
@@ -2290,6 +2291,16 @@ def _quantize_affine_float8(
     return fp8_tensor
 
 
+@torch.library.impl(quant_lib, "quantize_affine_float8", "Meta")
+def _quantize_affine_float8_meta(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
+) -> torch.Tensor:
+    return torch.empty_like(tensor, dtype=float8_dtype)
+
+
+@_register_custom_op(quant_lib, False)
 def _dequantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
@@ -2305,3 +2316,12 @@
 
     hp_tensor = fp8_tensor * scale_expanded
     return hp_tensor.to(output_dtype)
+
+
+@torch.library.impl(quant_lib, "dequantize_affine_float8", "Meta")
+def _dequantize_affine_float8_meta(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    return torch.empty_like(tensor, dtype=output_dtype)
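The Meta kernels above only describe output shape and dtype; they are what lets torch.compile and torch.export trace through the ops once they are no longer decomposed. Assuming importing torchao registers the ops as in this commit, the Meta path can be exercised directly with meta-device tensors (a sketch, not code from the commit):

import torch
import torchao  # assumed to register torch.ops.torchao.* on import

# Meta tensors carry only shape/dtype, so these calls hit the new Meta kernels.
x = torch.empty(8, 8, device="meta", dtype=torch.bfloat16)
scale = torch.empty((), device="meta")
q = torch.ops.torchao.quantize_affine_float8(x, scale, float8_dtype=torch.float8_e4m3fn)
hp = torch.ops.torchao.dequantize_affine_float8(q, scale, output_dtype=torch.bfloat16)
print(q.dtype, hp.dtype)  # torch.float8_e4m3fn torch.bfloat16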

torchao/utils.py

Lines changed: 10 additions & 3 deletions

@@ -179,7 +179,7 @@ def find_multiple(n: int, *args: int) -> int:
     return n + k - (n % k)
 
 
-def _register_custom_op(lib):
+def _register_custom_op(lib, inductor_decomposed=True):
     """This decorator is used to preserve some high level operators for torch.export.export
     while still allow them to be decomposed for inductor path
@@ -206,6 +206,12 @@ def _the_op_that_needs_to_be_preserved(...)
     """
     from torch._inductor.decomposition import register_decomposition
 
+    dispatch_key = (
+        "CompositeImplicitAutograd"
+        if inductor_decomposed
+        else "CompositeExplicitAutograd"
+    )
+
     def decorator(fn):
         if TORCH_VERSION_AT_LEAST_2_5:
             from torch._library.infer_schema import infer_schema
@@ -221,11 +227,12 @@ def decorator(fn):
             op_name = fn.__name__[1:]
             schema = op_name + infer_schema(fn, mutates_args={})
             lib.define(schema)
-            lib.impl(op_name, fn, "CompositeImplicitAutograd")
+            lib.impl(op_name, fn, dispatch_key)
 
             lib_namespace = lib.ns
             op = getattr(getattr(torch.ops, lib_namespace), op_name)
-            register_decomposition([op])(fn)
+            if inductor_decomposed:
+                register_decomposition([op])(fn)
             return op
         else:
             return fn
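With the new flag, each op decides whether inductor may inline its Python body: inductor_decomposed=True keeps the old behavior (CompositeImplicitAutograd plus a registered decomposition), while False switches to CompositeExplicitAutograd and skips the decomposition, which is why the float8 ops above also gain explicit Meta kernels. A hedged usage sketch; the my_ns namespace and _my_opaque_op are illustrative, not from the repo:

import torch
from torchao.utils import _register_custom_op

my_lib = torch.library.Library("my_ns", "FRAGMENT")  # hypothetical namespace


# Registered but *not* decomposed by inductor; the op name drops the leading underscore.
@_register_custom_op(my_lib, False)
def _my_opaque_op(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return x * scale


# Because the op stays opaque, it needs its own Meta kernel for torch.compile.
@torch.library.impl(my_lib, "my_opaque_op", "Meta")
def _my_opaque_op_meta(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


out = torch.compile(torch.ops.my_ns.my_opaque_op)(torch.randn(4), torch.tensor(2.0))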
