|
25 | 25 | AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
|
26 | 26 | AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
|
27 | 27 | AQFloat8WeightOnlyQuantizedLinearWeight,
|
| 28 | + AQGemliteInt4G64WeightOnlyQuantizedLinearWeight, |
| 29 | + AQInt4G32WeightOnlyQuantizedLinearWeight, |
| 30 | + AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight, |
28 | 31 | AQInt8DynamicallyQuantizedLinearWeight,
|
29 | 32 | AQInt8WeightOnlyQuantizedLinearWeight,
|
30 | 33 | AQInt8WeightOnlyQuantizedLinearWeight2,
|
@@ -1751,37 +1754,109 @@ def test_autoquant_min_sqnr(self, device, dtype):
|
1751 | 1754 | @unittest.skipIf(
|
1752 | 1755 | not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+."
|
1753 | 1756 | )
|
1754 |
| - def test_autoquant_float(self): |
| 1757 | + def test_autoquant_hp_float(self): |
1755 | 1758 | device = "cuda"
|
1756 | 1759 | dtype = torch.float32
|
1757 | 1760 | m, k, n = 128, 128, 128
|
1758 | 1761 | example_input = torch.randn(m, k, device=device, dtype=dtype)
|
1759 |
| - model = ( |
1760 |
| - torch.nn.Sequential( |
1761 |
| - torch.nn.ReLU(), |
1762 |
| - torch.nn.Linear(k, n), |
1763 |
| - torch.nn.ReLU(), |
| 1762 | + for qclass in torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST: |
| 1763 | + model = ( |
| 1764 | + torch.nn.Sequential( |
| 1765 | + torch.nn.ReLU(), |
| 1766 | + torch.nn.Linear(k, n, bias=True), |
| 1767 | + torch.nn.ReLU(), |
| 1768 | + ) |
| 1769 | + .to(device) |
| 1770 | + .to(dtype) |
1764 | 1771 | )
|
1765 |
| - .to(device) |
1766 |
| - .to(dtype) |
1767 |
| - ) |
1768 |
| - ref = model(example_input) |
1769 |
| - torchao.autoquant( |
1770 |
| - model, |
1771 |
| - qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, |
1772 |
| - ) |
1773 |
| - out = model(example_input) |
1774 |
| - from torchao.quantization.autoquant import ( |
1775 |
| - BFloat16Tensor, |
1776 |
| - Float16Tensor, |
1777 |
| - Float32Tensor, |
1778 |
| - ) |
| 1772 | + ref = model(example_input) |
| 1773 | + qtensor_class_list = [qclass] |
| 1774 | + torchao.autoquant( |
| 1775 | + model, |
| 1776 | + qtensor_class_list=qtensor_class_list, |
| 1777 | + ) |
| 1778 | + out = model(example_input) |
| 1779 | + self.assertIn( |
| 1780 | + type(model[1].weight), |
| 1781 | + qtensor_class_list, |
| 1782 | + ) |
| 1783 | + self.assertGreater(compute_error(out, ref), 40) |
1779 | 1784 |
|
1780 |
| - self.assertIn( |
1781 |
| - type(model[1].weight), [Float32Tensor, Float16Tensor, BFloat16Tensor] |
1782 |
| - ) |
1783 |
| - print(compute_error(out, ref)) |
1784 |
| - self.assertGreater(compute_error(out, ref), 60) |
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
    not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+."
)
@unittest.skipIf(not has_gemlite, "gemlite not available")
def test_autoquant_int4wo(self, device, dtype):
    """Run autoquant with each int4 weight-only class as the only candidate.

    For every class, a fresh ReLU -> Linear -> ReLU model is built, autoquant
    is restricted to that single class, and the test asserts that the Linear
    weight ends up as that class and that ``compute_error(ref, out)`` is
    greater than 20.
    """
    if device == "cpu":
        self.skipTest(f"int4wo is for cuda, not {device}")

    m, k, n = 128, 128, 128
    example_input = torch.randn(m, k, device=device, dtype=dtype)

    for qclass in [
        AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
        AQInt4G32WeightOnlyQuantizedLinearWeight,
        AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight,
    ]:
        # subTest labels a failure with the class under test and, on runners
        # that support subtests, lets the remaining classes still run.
        with self.subTest(qclass=qclass.__name__):
            model = (
                torch.nn.Sequential(
                    torch.nn.ReLU(),
                    torch.nn.Linear(k, n, bias=True),
                    torch.nn.ReLU(),
                )
                .to(device)
                .to(dtype)
            )
            # Reference output must be taken before autoquant mutates the model.
            ref = model(example_input)
            qtensor_class_list = [qclass]
            torchao.autoquant(
                model,
                qtensor_class_list=qtensor_class_list,
            )
            out = model(example_input)

            self.assertIn(type(model[1].weight), qtensor_class_list)
            self.assertGreater(compute_error(ref, out), 20)
| 1822 | + |
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90")
@unittest.skipIf(
    not TORCH_VERSION_AT_LEAST_2_5, "autoquant float8 option requires 2.5+."
)
def test_autoquant_float8(self, device, dtype):
    """Run autoquant with each float8 class as the only candidate.

    For every class, a fresh ReLU -> Linear -> ReLU model is built, autoquant
    is restricted to that single class, and the test asserts that the Linear
    weight ends up as that class and that ``compute_error(ref, out)`` is
    greater than 20.
    """
    if device == "cpu":
        # The previous message said "int4wo", copied from the int4 test.
        self.skipTest(f"float8 is for cuda, not {device}")

    # NOTE(review): this spot previously carried the remark "marlin sparse
    # layout failed when scale_t has a dimension of 1d". No marlin class is
    # exercised in this test, so the note may belong with the int4 test;
    # confirm before relying on it.
    m, k, n = 128, 128, 128
    example_input = torch.randn(m, k, device=device, dtype=dtype)

    for qclass in [
        AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
        AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
        AQFloat8WeightOnlyQuantizedLinearWeight,
    ]:
        # subTest labels a failure with the class under test and, on runners
        # that support subtests, lets the remaining classes still run.
        with self.subTest(qclass=qclass.__name__):
            model = (
                torch.nn.Sequential(
                    torch.nn.ReLU(),
                    torch.nn.Linear(k, n, bias=True),
                    torch.nn.ReLU(),
                )
                .to(device)
                .to(dtype)
            )
            # Reference output must be taken before autoquant mutates the model.
            ref = model(example_input)
            qtensor_class_list = [qclass]
            torchao.autoquant(
                model,
                qtensor_class_list=qtensor_class_list,
            )
            out = model(example_input)

            self.assertIn(type(model[1].weight), qtensor_class_list)
            self.assertGreater(compute_error(ref, out), 20)
1785 | 1860 |
|
1786 | 1861 |
|
1787 | 1862 | @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
|
|
0 commit comments