     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
     AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
+    AQInt4G32WeightOnlyQuantizedLinearWeight,
+    AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight2,
@@ -1751,37 +1754,109 @@ def test_autoquant_min_sqnr(self, device, dtype):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+."
     )
-    def test_autoquant_float(self):
+    def test_autoquant_hp_float(self):
         device = "cuda"
         dtype = torch.float32
         m, k, n = 128, 128, 128
         example_input = torch.randn(m, k, device=device, dtype=dtype)
-        model = (
-            torch.nn.Sequential(
-                torch.nn.ReLU(),
-                torch.nn.Linear(k, n),
-                torch.nn.ReLU(),
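+        # Each high-precision class is tried on its own so the selected
+        # weight type can be asserted exactly.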
+        for qclass in torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST:
+            model = (
+                torch.nn.Sequential(
+                    torch.nn.ReLU(),
+                    torch.nn.Linear(k, n, bias=True),
+                    torch.nn.ReLU(),
+                )
+                .to(device)
+                .to(dtype)
             )
-            .to(device)
-            .to(dtype)
-        )
-        ref = model(example_input)
-        torchao.autoquant(
-            model,
-            qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
-        )
-        out = model(example_input)
-        from torchao.quantization.autoquant import (
-            BFloat16Tensor,
-            Float16Tensor,
-            Float32Tensor,
-        )
+            ref = model(example_input)
+            qtensor_class_list = [qclass]
+            torchao.autoquant(
+                model,
+                qtensor_class_list=qtensor_class_list,
+            )
+            out = model(example_input)
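+            # autoquant is expected to swap the Linear weight to the lone
+            # candidate class supplied above.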
+            self.assertIn(
+                type(model[1].weight),
+                qtensor_class_list,
+            )
+            self.assertGreater(compute_error(out, ref), 40)
 
-        self.assertIn(
-            type(model[1].weight), [Float32Tensor, Float16Tensor, BFloat16Tensor]
-        )
-        print(compute_error(out, ref))
-        self.assertGreater(compute_error(out, ref), 60)
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+."
+    )
+    @unittest.skipIf(not has_gemlite, "gemlite not available")
+    def test_autoquant_int4wo(self, device, dtype):
+        if device == "cpu":
+            self.skipTest(f"int4wo is for cuda, not {device}")
+
+        m, k, n = 128, 128, 128
+        example_input = torch.randn(m, k, device=device, dtype=dtype)
+
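+        # Each int4 weight-only candidate runs in isolation: gemlite (group
+        # size 64), plain int4 (group size 32), and marlin sparse (group
+        # size 128).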
+        for qclass in [
+            AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
+            AQInt4G32WeightOnlyQuantizedLinearWeight,
+            AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight,
+        ]:
+            model = (
+                torch.nn.Sequential(
+                    torch.nn.ReLU(),
+                    torch.nn.Linear(k, n, bias=True),
+                    torch.nn.ReLU(),
+                )
+                .to(device)
+                .to(dtype)
+            )
+            ref = model(example_input)
+            qtensor_class_list = [qclass]
+            torchao.autoquant(
+                model,
+                qtensor_class_list=qtensor_class_list,
+            )
+            out = model(example_input)
+
+            self.assertIn(type(model[1].weight), qtensor_class_list)
+            self.assertGreater(compute_error(ref, out), 20)
+
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not is_sm_at_least_90(), "Need CUDA arch at least SM90")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_5, "autoquant float8 option requires 2.5+."
+    )
+    def test_autoquant_float8(self, device, dtype):
+        if device == "cpu":
+            self.skipTest(f"float8 is for cuda, not {device}")
+
+        # note: marlin sparse layout fails when scale_t is 1-dimensional
+        m, k, n = 128, 128, 128
+        example_input = torch.randn(m, k, device=device, dtype=dtype)
+
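+        # Cover all three float8 candidates: per-row dynamic, per-tensor
+        # dynamic, and weight-only.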
+        for qclass in [
+            AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
+            AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
+            AQFloat8WeightOnlyQuantizedLinearWeight,
+        ]:
+            model = (
+                torch.nn.Sequential(
+                    torch.nn.ReLU(),
+                    torch.nn.Linear(k, n, bias=True),
+                    torch.nn.ReLU(),
+                )
+                .to(device)
+                .to(dtype)
+            )
+            ref = model(example_input)
+            qtensor_class_list = [qclass]
+            torchao.autoquant(
+                model,
+                qtensor_class_list=qtensor_class_list,
+            )
+            out = model(example_input)
+
+            self.assertIn(type(model[1].weight), qtensor_class_list)
+            self.assertGreater(compute_error(ref, out), 20)
 
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")