
Commit 8b57afe

[CPU] Enable DA8W4 on CPU (#2128)
* [CPU] Enable int8_dynamic_activation_int4_weight with Int4CPULayout
* Fix format issue
* Add Int8DynamicActInt4WeightCPULayout
* Remove dispatch for t()
* Add cpp kernel for weight packing and GEMM
* Register AQT linear dispatch for da8w4 linear
* Fix issues with torch.compile
* Fix DA8W4CPUAQTTensorImpl.get_plain
* Test DA8W4CPUAQTTensorImpl.get_plain in UT
* Skip UT if CPP kernel not built
* Add AVX512_VNNI implementation for small M
* Improve performance
* Support symmetric quantization of activation
* Refine code
* Refine code
* Put in a separate file
* Bug fix
* Refine code
1 parent 8940aa7 commit 8b57afe
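
For context, a minimal usage sketch of the new DA8W4 CPU path, distilled from the test added below. The import paths and the toy model are illustrative assumptions; the config, layout, group_size, and MappingType choices mirror this commit's test.

# Usage sketch (not part of the diff). Import paths and the toy model are
# assumptions for illustration; the config and layout names come from the
# new test in test/quantization/test_quant_api.py below.
import torch
from torchao.dtypes import Int8DynamicActInt4WeightCPULayout
from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.quant_primitives import MappingType

model = torch.nn.Sequential(torch.nn.Linear(256, 256, bias=False)).eval().to(torch.bfloat16)
quantize_(
    model,
    Int8DynamicActivationInt4WeightConfig(
        group_size=32,
        layout=Int8DynamicActInt4WeightCPULayout(),
        act_mapping_type=MappingType.ASYMMETRIC,  # SYMMETRIC requires PyTorch 2.8+ per the test
    ),
)
with torch.no_grad():
    y = torch.compile(model, fullgraph=True, dynamic=True)(
        torch.randn(16, 256, dtype=torch.bfloat16)
    )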


10 files changed: +1296, -30 lines changed


setup.py

Lines changed: 23 additions & 14 deletions
@@ -385,20 +385,29 @@ def get_extensions():
         extra_compile_args["cxx"].extend(
             ["-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always"]
         )
-        if (
-            use_cpu_kernels
-            and is_linux
-            and hasattr(torch._C._cpu, "_is_avx512_supported")
-            and torch._C._cpu._is_avx512_supported()
-        ):
-            extra_compile_args["cxx"].extend(
-                [
-                    "-DCPU_CAPABILITY_AVX512",
-                    "-march=native",
-                    "-mfma",
-                    "-fopenmp",
-                ]
-            )
+
+        if use_cpu_kernels and is_linux:
+            if (
+                hasattr(torch._C._cpu, "_is_avx512_supported")
+                and torch._C._cpu._is_avx512_supported()
+            ):
+                extra_compile_args["cxx"].extend(
+                    [
+                        "-DCPU_CAPABILITY_AVX512",
+                        "-march=native",
+                        "-mfma",
+                        "-fopenmp",
+                    ]
+                )
+            if (
+                hasattr(torch._C._cpu, "_is_avx512_vnni_supported")
+                and torch._C._cpu._is_avx512_vnni_supported()
+            ):
+                extra_compile_args["cxx"].extend(
+                    [
+                        "-DCPU_CAPABILITY_AVX512_VNNI",
+                    ]
+                )
 
         if debug_mode:
             extra_compile_args["cxx"].append("-g")
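
The two capability checks above use private torch hooks. A quick way to see which compile-time flags the build would enable on a given machine is to probe the same hooks directly; a sketch, keeping the hasattr guards exactly as setup.py does because these are private APIs.

# Probe the same private CPU-capability hooks that setup.py checks at build time.
import torch

cpu = torch._C._cpu
avx512 = hasattr(cpu, "_is_avx512_supported") and cpu._is_avx512_supported()
vnni = hasattr(cpu, "_is_avx512_vnni_supported") and cpu._is_avx512_vnni_supported()
print(f"-DCPU_CAPABILITY_AVX512: {avx512}, -DCPU_CAPABILITY_AVX512_VNNI: {vnni}")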

test/quantization/test_quant_api.py

Lines changed: 68 additions & 0 deletions
@@ -29,6 +29,7 @@
     AffineQuantizedTensor,
     Int4CPULayout,
     Int4XPULayout,
+    Int8DynamicActInt4WeightCPULayout,
     PlainLayout,
     QDQLayout,
     TensorCoreTiledLayout,
@@ -70,6 +71,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    TORCH_VERSION_AT_LEAST_2_7,
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_90,
@@ -695,6 +697,72 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
         assert "_weight_int4pack_mm_for_cpu" in code[0]
         assert "aten.mm.default" not in code[0]
 
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"),
+        reason="cpp kernels not built",
+    )
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Test only enabled for 2.7+")
+    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
+    @common_utils.parametrize("x_dim", [2, 3])
+    @common_utils.parametrize("bias", [True, False])
+    @common_utils.parametrize("bs", [1, 160])
+    @common_utils.parametrize("sym_quant_a", [True, False])
+    def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a):
+        if sym_quant_a and not TORCH_VERSION_AT_LEAST_2_8:
+            # not supported until PT 2.8
+            return
+        device = "cpu"
+        m = ToyLinearModel(bias=bias).eval().to(dtype).to(device)
+        m2 = copy.deepcopy(m)
+        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
+        if x_dim == 3:
+            example_inputs = (example_inputs[0].unsqueeze(0),)
+
+        with torch.no_grad():
+            # Currently, the difference between Int8DynamicActInt4WeightCPULayout and PlainLayout
+            # is that the former packs two int4 weights into one int8, while the latter does not.
+            quantize_(
+                m,
+                Int8DynamicActivationInt4WeightConfig(
+                    group_size=32,
+                    layout=Int8DynamicActInt4WeightCPULayout(),
+                    act_mapping_type=MappingType.SYMMETRIC
+                    if sym_quant_a
+                    else MappingType.ASYMMETRIC,
+                ),
+            )
+            y, code = torch._inductor.utils.run_and_get_code(
+                torch.compile(m, fullgraph=True, dynamic=True),
+                *example_inputs,
+            )
+            # ensure the expected op is in the code
+            assert "torch.ops.torchao.da8w4_linear_cpu.default" in code[0]
+            quantize_(
+                m2,
+                int8_dynamic_activation_int4_weight(
+                    group_size=32,
+                    layout=PlainLayout(),
+                    act_mapping_type=MappingType.SYMMETRIC
+                    if sym_quant_a
+                    else MappingType.ASYMMETRIC,
+                ),
+            )
+            torch._dynamo.reset()  # may segfault without this
+            y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs)
+            atol, rtol = 4e-7, 1e-5
+            if dtype == torch.bfloat16:
+                atol, rtol = 1e-2, 3e-3
+            elif dtype == torch.half:
+                atol, rtol = 6e-3, 2e-3
+            assert torch.allclose(y, y2, atol=atol, rtol=rtol)
+            # Test get_plain by dequantize()
+            dqw1 = m.linear1.weight.original_weight_tensor.dequantize()
+            dqw2 = m.linear2.weight.original_weight_tensor.dequantize()
+            dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize()
+            dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize()
+            assert torch.allclose(dqw1, dqw1_ref)
+            assert torch.allclose(dqw2, dqw2_ref)
+
     # TODO(#1690): move to new config names
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
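
The in-test comment notes that Int8DynamicActInt4WeightCPULayout packs two int4 weights into one int8 byte, which PlainLayout does not. The diff does not include the new C++ packing kernel, so the nibble order below is only an assumption used to illustrate the idea.

# Illustration of packing two uint4 values per byte. The real packing is done
# by the new cpp kernel (not shown in this diff); low-nibble-first order here
# is an assumption for demonstration only.
import torch

def pack_uint4_pairs(w: torch.Tensor) -> torch.Tensor:
    # w: uint8 tensor with values in [0, 15], last dim of even length
    lo, hi = w[..., 0::2], w[..., 1::2]
    return lo | (hi << 4)

def unpack_uint4_pairs(packed: torch.Tensor) -> torch.Tensor:
    lo, hi = packed & 0xF, (packed >> 4) & 0xF
    return torch.stack((lo, hi), dim=-1).flatten(-2)

w = torch.randint(0, 16, (8, 32), dtype=torch.uint8)
assert torch.equal(unpack_uint4_pairs(pack_uint4_pairs(w)), w)  # round-trip is lossless

Packing halves the weight storage relative to one int4 value per byte, which is why the test compares the packed layout against PlainLayout for numerical equivalence via dequantize().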
