|
30 | 30 | Quantizer,
|
31 | 31 | TwoStepQuantizer,
|
32 | 32 | _replace_with_custom_fn_if_matches_filter,
|
| 33 | + float8_dynamic_activation_float8_weight, |
| 34 | + float8_static_activation_float8_weight, |
| 35 | + float8_weight_only, |
33 | 36 | int4_weight_only,
|
34 | 37 | int8_dynamic_activation_int4_weight,
|
35 | 38 | int8_dynamic_activation_int8_weight,
|
|
46 | 49 | TORCH_VERSION_AT_LEAST_2_4,
|
47 | 50 | TORCH_VERSION_AT_LEAST_2_5,
|
48 | 51 | TORCH_VERSION_AT_LEAST_2_6,
|
| 52 | + is_sm_at_least_89, |
49 | 53 | unwrap_tensor_subclass,
|
50 | 54 | )
|
51 | 55 |
|
@@ -784,28 +788,52 @@ def test_int4wo_cpu(self, dtype, x_dim):
|
784 | 788 | assert "_weight_int4pack_mm_for_cpu" in code[0]
|
785 | 789 | assert "aten.mm.default" not in code[0]
|
786 | 790 |
|
| 791 | + # TODO(#1690): move to new config names |
787 | 792 | @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
|
788 | 793 | @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
|
789 |
| - def test_int4_weight_only_numerics(self): |
| 794 | + @common_utils.parametrize( |
| 795 | + "config", |
| 796 | + [ |
| 797 | + int4_weight_only(), |
| 798 | + float8_weight_only(), |
| 799 | + float8_dynamic_activation_float8_weight(), |
| 800 | + float8_static_activation_float8_weight( |
| 801 | + scale=torch.tensor([1.0], device="cuda") |
| 802 | + ), |
| 803 | + ], |
| 804 | + ) |
| 805 | + def test_workflow_e2e_numerics(self, config): |
790 | 806 | """
|
791 | 807 | Simple test of e2e int4_weight_only workflow, comparing numerics
|
792 | 808 | to a bfloat16 baseline.
|
793 | 809 | """
|
| 810 | + if ( |
| 811 | + isinstance( |
| 812 | + config, |
| 813 | + ( |
| 814 | + float8_dynamic_activation_float8_weight, |
| 815 | + float8_static_activation_float8_weight, |
| 816 | + ), |
| 817 | + ) |
| 818 | + and not is_sm_at_least_89() |
| 819 | + ): |
| 820 | + return unittest.skip("requires CUDA capability 8.9 or greater") |
| 821 | + |
794 | 822 | # set up inputs
|
795 | 823 | x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
|
796 | 824 | # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
|
797 | 825 | # is that expected?
|
798 | 826 | m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
|
799 |
| - m_int4_wo = copy.deepcopy(m_ref) |
| 827 | + m_q = copy.deepcopy(m_ref) |
800 | 828 |
|
801 | 829 | # quantize
|
802 |
| - quantize_(m_int4_wo, int4_weight_only()) |
| 830 | + quantize_(m_q, config) |
803 | 831 |
|
804 | 832 | with torch.no_grad():
|
805 | 833 | y_ref = m_ref(x)
|
806 |
| - y_int4_wo = m_int4_wo(x) |
| 834 | + y_q = m_q(x) |
807 | 835 |
|
808 |
| - sqnr = compute_error(y_ref, y_int4_wo) |
| 836 | + sqnr = compute_error(y_ref, y_q) |
809 | 837 | assert sqnr >= 20, f"SQNR {sqnr} is too low"
|
810 | 838 |
|
811 | 839 |
|
|
0 commit comments