
Commit cf45336

Relax dtype requirements for int4 and float8 quants in autoquant (#1571)
* Relax dtype requirements for int4 quants in autoquant

  Summary: Some of the int4 quantization options only work with bfloat16/float16. Previously we required the model to already be in a compatible dtype to apply these in autoquant; this PR relaxes that constraint by converting the weight and activation to compatible dtypes.

  Test Plan: python test/integration/test_integration.py -k test_autoquant_int4wo

  Reviewers: Subscribers: Tasks: Tags:

* remove prints
* add float8
* run pre-commit
* run pre-commit
* manual format
* enable bias=True test
* remove print
1 parent f520c91 commit cf45336
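With the relaxed constraint, the int4/float8 autoquant classes can be applied to a model that is not already in bfloat16/float16. A minimal usage sketch modeled on the new tests below; the shapes, device, and choice of qtensor class are illustrative, and the import path is assumed to match the one used in the tests:

import torch
import torchao
from torchao.quantization.autoquant import AQInt4G32WeightOnlyQuantizedLinearWeight

# fp32 model; before this change it would have had to be cast to bf16/fp16 first
model = (
    torch.nn.Sequential(
        torch.nn.ReLU(),
        torch.nn.Linear(128, 128, bias=True),
        torch.nn.ReLU(),
    )
    .to("cuda")
    .to(torch.float32)
)
example_input = torch.randn(128, 128, device="cuda", dtype=torch.float32)

torchao.autoquant(
    model,
    qtensor_class_list=[AQInt4G32WeightOnlyQuantizedLinearWeight],
)
out = model(example_input)  # first forward pass picks and applies the quantization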

File tree

3 files changed: +207 −69 lines changed


test/integration/test_integration.py

+100 −25
@@ -25,6 +25,9 @@
     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
     AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
+    AQInt4G32WeightOnlyQuantizedLinearWeight,
+    AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight2,
@@ -1751,37 +1754,109 @@ def test_autoquant_min_sqnr(self, device, dtype):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+."
     )
-    def test_autoquant_float(self):
+    def test_autoquant_hp_float(self):
         device = "cuda"
         dtype = torch.float32
         m, k, n = 128, 128, 128
         example_input = torch.randn(m, k, device=device, dtype=dtype)
-        model = (
-            torch.nn.Sequential(
-                torch.nn.ReLU(),
-                torch.nn.Linear(k, n),
-                torch.nn.ReLU(),
+        for qclass in torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST:
+            model = (
+                torch.nn.Sequential(
+                    torch.nn.ReLU(),
+                    torch.nn.Linear(k, n, bias=True),
+                    torch.nn.ReLU(),
+                )
+                .to(device)
+                .to(dtype)
             )
-            .to(device)
-            .to(dtype)
-        )
-        ref = model(example_input)
-        torchao.autoquant(
-            model,
-            qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
-        )
-        out = model(example_input)
-        from torchao.quantization.autoquant import (
-            BFloat16Tensor,
-            Float16Tensor,
-            Float32Tensor,
-        )
+            ref = model(example_input)
+            qtensor_class_list = [qclass]
+            torchao.autoquant(
+                model,
+                qtensor_class_list=qtensor_class_list,
+            )
+            out = model(example_input)
+            self.assertIn(
+                type(model[1].weight),
+                qtensor_class_list,
+            )
+            self.assertGreater(compute_error(out, ref), 40)
 
-        self.assertIn(
-            type(model[1].weight), [Float32Tensor, Float16Tensor, BFloat16Tensor]
-        )
-        print(compute_error(out, ref))
-        self.assertGreater(compute_error(out, ref), 60)
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+."
+    )
+    @unittest.skipIf(not has_gemlite, "gemlite not available")
+    def test_autoquant_int4wo(self, device, dtype):
+        if device == "cpu":
+            self.skipTest(f"int4wo is for cuda, not {device}")
+
+        m, k, n = 128, 128, 128
+        example_input = torch.randn(m, k, device=device, dtype=dtype)
+
+        for qclass in [
+            AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
+            AQInt4G32WeightOnlyQuantizedLinearWeight,
+            AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight,
+        ]:
+            model = (
+                torch.nn.Sequential(
+                    torch.nn.ReLU(),
+                    torch.nn.Linear(k, n, bias=True),
+                    torch.nn.ReLU(),
+                )
+                .to(device)
+                .to(dtype)
+            )
+            ref = model(example_input)
+            qtensor_class_list = [qclass]
+            torchao.autoquant(
+                model,
+                qtensor_class_list=qtensor_class_list,
+            )
+            out = model(example_input)
+
+            self.assertIn(type(model[1].weight), qtensor_class_list)
+            self.assertGreater(compute_error(ref, out), 20)
+
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+."
+    )
+    def test_autoquant_float8(self, device, dtype):
+        if device == "cpu":
+            self.skipTest(f"int4wo is for cuda, not {device}")
+
+        # note: marlin sparse layout failed when scale_t has a dimension of 1d
+        m, k, n = 128, 128, 128
+        example_input = torch.randn(m, k, device=device, dtype=dtype)
+
+        for qclass in [
+            AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
+            AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
+            AQFloat8WeightOnlyQuantizedLinearWeight,
+        ]:
+            model = (
+                torch.nn.Sequential(
+                    torch.nn.ReLU(),
+                    torch.nn.Linear(k, n, bias=True),
+                    torch.nn.ReLU(),
+                )
+                .to(device)
+                .to(dtype)
+            )
+            ref = model(example_input)
+            qtensor_class_list = [qclass]
+            torchao.autoquant(
+                model,
+                qtensor_class_list=qtensor_class_list,
+            )
+            out = model(example_input)
+
+            self.assertIn(type(model[1].weight), qtensor_class_list)
+            self.assertGreater(compute_error(ref, out), 20)
 
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
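The tests above pass fp32/bf16/fp16 models directly to autoquant; per the commit message, the relaxation that makes this work is casting the weight and activation to a dtype the quantized kernels support. A rough schematic of that idea, not torchao's actual code; `w_int4`, `target_dtype`, and `dequantize()` are hypothetical placeholders:

import torch

def int4_linear_with_dtype_relaxation(act: torch.Tensor, w_int4, bias=None):
    # `w_int4` stands in for an autoquant int4 weight subclass; `target_dtype`
    # and `dequantize()` are placeholder names used only for illustration.
    orig_dtype = act.dtype
    act = act.to(w_int4.target_dtype)  # cast activation to a supported dtype, e.g. bf16
    out = torch.nn.functional.linear(act, w_int4.dequantize(), bias)
    return out.to(orig_dtype)  # cast the result back to the caller's original dtype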

torchao/dtypes/uintx/marlin_sparse_layout.py

+5
@@ -227,6 +227,11 @@ def from_plain(
         # Linear layers are (in_features, out_features) but the int_data that is reaching this point
         # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code.
         q_w_24 = int_data.t()
+        # addressing the case when scale has dimension 1, happens when
+        # weight_shape[-1] == group_size == 128
+        if scale.ndim == 1:
+            scale = scale.reshape(scale.shape[0], -1)
+
         scale_t = scale.t()
 
         if not torch.cuda.get_device_capability()[0] >= 8:
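For context on the guard added above: when in_features equals the group size (e.g. both 128) there is a single quantization group per output row, so the scale can arrive as a 1-D tensor; the reshape gives it the 2-D shape the subsequent transpose expects. A standalone sketch with assumed shapes:

import torch

out_features = 128
scale = torch.rand(out_features)  # 1-D scale: one group per output row

if scale.ndim == 1:
    scale = scale.reshape(scale.shape[0], -1)  # -> (out_features, 1)

scale_t = scale.t()  # -> (1, out_features), the 2-D layout the marlin packing code expects
print(scale_t.shape)  # torch.Size([1, 128])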
