@@ -29,6 +29,7 @@
     AffineQuantizedTensor,
     Int4CPULayout,
     Int4XPULayout,
+    Int8DynamicActInt4WeightCPULayout,
     PlainLayout,
     QDQLayout,
     TensorCoreTiledLayout,
@@ -70,6 +71,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    TORCH_VERSION_AT_LEAST_2_7,
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_90,
@@ -695,6 +697,72 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
         assert "_weight_int4pack_mm_for_cpu" in code[0]
         assert "aten.mm.default" not in code[0]

+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"),
+        reason="cpp kernels not built",
+    )
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Test only enabled for 2.7+")
+    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
+    @common_utils.parametrize("x_dim", [2, 3])
+    @common_utils.parametrize("bias", [True, False])
+    @common_utils.parametrize("bs", [1, 160])
+    @common_utils.parametrize("sym_quant_a", [True, False])
+    def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a):
+        if sym_quant_a and not TORCH_VERSION_AT_LEAST_2_8:
+            # not supported until PT 2.8
+            return
+        device = "cpu"
+        m = ToyLinearModel(bias=bias).eval().to(dtype).to(device)
+        m2 = copy.deepcopy(m)
+        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
+        if x_dim == 3:
+            example_inputs = (example_inputs[0].unsqueeze(0),)
+
+        with torch.no_grad():
+            # Currently, the difference between Int8DynamicActInt4WeightCPULayout and PlainLayout
+            # is that the former packs two int4 weights into one int8, while the latter does not.
+            quantize_(
+                m,
+                Int8DynamicActivationInt4WeightConfig(
+                    group_size=32,
+                    layout=Int8DynamicActInt4WeightCPULayout(),
+                    act_mapping_type=MappingType.SYMMETRIC
+                    if sym_quant_a
+                    else MappingType.ASYMMETRIC,
+                ),
+            )
+            y, code = torch._inductor.utils.run_and_get_code(
+                torch.compile(m, fullgraph=True, dynamic=True),
+                *example_inputs,
+            )
+            # ensure the expected op is in the code
+            assert "torch.ops.torchao.da8w4_linear_cpu.default" in code[0]
+            quantize_(
+                m2,
+                int8_dynamic_activation_int4_weight(
+                    group_size=32,
+                    layout=PlainLayout(),
+                    act_mapping_type=MappingType.SYMMETRIC
+                    if sym_quant_a
+                    else MappingType.ASYMMETRIC,
+                ),
+            )
+            torch._dynamo.reset()  # may segfault without this
+            y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs)
+            atol, rtol = 4e-7, 1e-5
+            if dtype == torch.bfloat16:
+                atol, rtol = 1e-2, 3e-3
+            elif dtype == torch.half:
+                atol, rtol = 6e-3, 2e-3
+            assert torch.allclose(y, y2, atol=atol, rtol=rtol)
+            # Test get_plain by dequantize()
+            dqw1 = m.linear1.weight.original_weight_tensor.dequantize()
+            dqw2 = m.linear2.weight.original_weight_tensor.dequantize()
+            dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize()
+            dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize()
+            assert torch.allclose(dqw1, dqw1_ref)
+            assert torch.allclose(dqw2, dqw2_ref)
+
     # TODO(#1690): move to new config names
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
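The comment in the new test notes that Int8DynamicActInt4WeightCPULayout differs from PlainLayout mainly in that it packs two int4 weights into one int8. The sketch below is only an illustration of that packing idea, using hypothetical helper names; it is not the actual torchao packing code, whose nibble order and packing dimension may differ.

import torch

# Sketch only: pack pairs of signed int4 values (range [-8, 7]) into single bytes.
def pack_int4_pairs(w: torch.Tensor) -> torch.Tensor:
    # w: torch.int8 tensor holding int4 values, last dim of even length
    assert w.shape[-1] % 2 == 0
    lo = (w[..., 0::2] & 0xF).to(torch.uint8)  # low nibble
    hi = (w[..., 1::2] & 0xF).to(torch.uint8)  # high nibble
    return (hi << 4) | lo                      # one uint8 byte per int4 pair

def unpack_int4_pairs(packed: torch.Tensor) -> torch.Tensor:
    lo = (packed & 0xF).to(torch.int8)
    hi = (packed >> 4).to(torch.int8)
    # sign-extend the 4-bit values back to [-8, 7]
    lo = torch.where(lo >= 8, lo - 16, lo)
    hi = torch.where(hi >= 8, hi - 16, hi)
    return torch.stack((lo, hi), dim=-1).flatten(-2)

w = torch.randint(-8, 8, (4, 32), dtype=torch.int8)
assert torch.equal(unpack_int4_pairs(pack_int4_pairs(w)), w)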