Skip to content

Commit df532f0

Browse files
authored
Fix slice and padding for TensorCoreTiledLayout (#2015)
* Fix slice and padding for TensorCoreTiledLayout for int4 weight only quantization Summary: Previously some of the code paths are not exercised, so the bug was not discovered but there are some bug related to slice operation and padding, basically scale and zero_point are not padded before, this results in errors when it is required. Test Plan: python test/dtypes/test_affine_quantized.py -k test_slice Reviewers: Subscribers: Tasks: Tags: * skip if no cuda * update callsites for post_process * add back missing post process * adding missing arg for floatx
1 parent 5ded23c commit df532f0

10 files changed

+126
-41
lines changed

test/dtypes/test_affine_quantized.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from torchao.core.config import AOBaseConfig
1818
from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
1919
from torchao.quantization import (
20+
Int4WeightOnlyConfig,
2021
Int8DynamicActivationInt8WeightConfig,
2122
float8_weight_only,
2223
int4_dynamic_activation_int4_weight,
@@ -27,7 +28,7 @@
2728
quantize_,
2829
)
2930
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
30-
from torchao.testing.utils import skip_if_rocm
31+
from torchao.testing.utils import skip_if_no_cuda, skip_if_rocm
3132
from torchao.utils import (
3233
TORCH_VERSION_AT_LEAST_2_5,
3334
TORCH_VERSION_AT_LEAST_2_6,
@@ -307,6 +308,19 @@ def test_alias(self, device, dtype):
307308
quantize_(dummy, Int8DynamicActivationInt8WeightConfig())
308309
_ = dummy.weight[...]
309310

311+
@common_utils.parametrize("device", ["cuda"])
312+
@common_utils.parametrize("dtype", [torch.bfloat16])
313+
@skip_if_no_cuda()
314+
def test_slice(self, device, dtype):
315+
# in_feature not divisible by 1024
316+
# out_feature not divisible by 8
317+
# to test slice + padding for int4 weight only quantization
318+
dummy = nn.Linear(256, 321, dtype=dtype, device=device)
319+
quantize_(dummy, Int4WeightOnlyConfig())
320+
# make sure these run without error
321+
_ = dummy.weight.narrow(0, 0, 64)
322+
_ = dummy.weight.narrow(1, 0, 128)
323+
310324

311325
common_utils.instantiate_parametrized_tests(TestAffineQuantized)
312326
common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)

torchao/dtypes/affine_quantized_tensor.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,9 @@ def from_hp_to_intx(
284284
)
285285
# Note: output will be uint8 tensor for sub byte tensors for now
286286

287-
data = _layout.post_process(data)
287+
data, scale, zero_point = _layout.post_process(
288+
data, scale, zero_point, block_size
289+
)
288290
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
289291
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
290292
return cls(
@@ -335,7 +337,7 @@ def from_hp_to_intx_static(
335337
zero_point_domain,
336338
)
337339

338-
int_data = _layout.post_process(int_data)
340+
int_data, scale, zero_point = _layout.post_process(int_data, scale, zero_point)
339341

340342
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
341343
tensor_impl = tensor_impl_ctr(int_data, scale, zero_point, _layout)
@@ -429,7 +431,9 @@ def from_hp_to_fpx(
429431
# Note: these ops are hardcoded to have per axis quantization (axis=1) right now
430432
scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
431433
floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
432-
floatx_packed = _layout.post_process(floatx_unpacked)
434+
floatx_packed, scale, _ = _layout.post_process(
435+
floatx_unpacked, scale, None, block_size
436+
)
433437

434438
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
435439
tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)

torchao/dtypes/uintx/gemlite_layout.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -297,14 +297,31 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
297297
if func is aten.slice.Tensor:
298298
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
299299
if dim == 0:
300+
assert step == 1, "Only step == 1 is supported in slicing right now"
300301
int_data, scale, zero_point = self.get_plain()
302+
data_len = int_data.shape[dim]
303+
param_dim = 1 - dim
304+
scale_len = scale.shape[param_dim]
305+
ratio = data_len / scale_len
306+
start_scale = int(start / ratio)
307+
end_scale = int(end / ratio)
308+
301309
int_data = aten.slice.Tensor(int_data, dim, start, end, step)
302-
int_data = self._layout.post_process(int_data)
310+
scale = aten.slice.Tensor(
311+
scale, param_dim, start_scale, end_scale, step
312+
)
313+
if zero_point is not None and zero_point.numel() > 0:
314+
zero_point = aten.slice.Tensor(
315+
zero_point, param_dim, start_scale, end_scale, step
316+
)
317+
else:
318+
zero_point = None
319+
303320
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
304321
return return_and_correct_aliasing(func, args, kwargs, sliced)
305322
elif dim == 1:
306-
int_data, scale, zero_point = self.get_plain()
307323
assert step == 1, "Only step == 1 is supported in slicing right now"
324+
int_data, scale, zero_point = self.get_plain()
308325
data_len = int_data.shape[dim]
309326
# scale and zero_point are transposed compared to int_data
310327
param_dim = 1 - dim
@@ -314,7 +331,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
314331
end_scale = int(end / ratio)
315332

316333
int_data = aten.slice.Tensor(int_data, dim, start, end, step)
317-
# this is to handle padding
318334
scale = aten.slice.Tensor(
319335
scale, param_dim, start_scale, end_scale, step
320336
)
@@ -324,9 +340,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
324340
)
325341
else:
326342
zero_point = None
327-
# import fbvscode; fbvscode.set_trace()
328343
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
329-
return sliced
344+
return return_and_correct_aliasing(func, args, kwargs, sliced)
330345
else:
331346
raise NotImplementedError(
332347
f"GemliteAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"

torchao/dtypes/uintx/int4_cpu_layout.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -192,31 +192,26 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
192192

193193
if func is aten.slice.Tensor:
194194
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
195-
if dim == 0:
196-
int_data, scale, zero_point = self.get_plain()
197-
int_data = aten.slice.Tensor(int_data, dim, start, end, step)
198-
# this is to handle padding
199-
int_data = self._layout.post_process(int_data)
200-
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
201-
return return_and_correct_aliasing(func, args, kwargs, sliced)
202-
elif dim == 1:
203-
int_data, scale, zero_point = self.get_plain()
195+
if dim in [0, 1]:
204196
assert step == 1, "Only step == 1 is supported in slicing right now"
197+
int_data, scale, zero_point = self.get_plain()
205198
data_len = int_data.shape[dim]
206199
scale_len = scale.shape[dim]
207200
ratio = data_len / scale_len
208201
start_scale = int(start / ratio)
209202
end_scale = int(end / ratio)
210203

211204
int_data = aten.slice.Tensor(int_data, dim, start, end, step)
212-
# this is to handle padding
213-
int_data = self._layout.post_process(int_data)
214205
scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
215206
zero_point = aten.slice.Tensor(
216207
zero_point, dim, start_scale, end_scale, step
217208
)
209+
# this is to handle padding
210+
int_data, scale, zero_point = self._layout.post_process(
211+
int_data, scale, zero_point, self.block_size
212+
)
218213
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
219-
return sliced
214+
return return_and_correct_aliasing(func, args, kwargs, sliced)
220215
else:
221216
raise NotImplementedError(
222217
f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
@@ -228,6 +223,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
228223

229224
__torch_function__ = torch._C._disabled_torch_function_impl
230225

226+
@property
227+
def block_size(self):
228+
from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
229+
230+
scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
231+
cur_shape = self.shape
232+
assert len(cur_shape) == 4
233+
inner_k_tiles = cur_shape[-1] * 2
234+
original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
235+
groupsize = int(original_shape[1] / scale.shape[-2])
236+
return (1, groupsize)
237+
231238
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
232239
from torchao.quantization.quant_primitives import (
233240
ZeroPointDomain,

torchao/dtypes/uintx/marlin_qqq_tensor.py

-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ def from_hp_to_intx(
7272
data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
7373
input_float, nbits, group_size
7474
)
75-
data = _layout.post_process(data)
7675
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
7776
tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout)
7877
return cls(

torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py

-1
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,6 @@ def from_hp_to_intx(
429429
)
430430
# Note: output will be uint8 tensor for sub byte tensors for now
431431

432-
data = _layout.post_process(data)
433432
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
434433
tensor_impl = tensor_impl_ctr(
435434
data, scale, zero_point, _layout, **(tensor_impl_ctr_kwargs or {})

torchao/dtypes/uintx/tensor_core_tiled_layout.py

+35-14
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,30 @@ def pre_process_static(
153153
zero_point = torch.nn.functional.pad(zero_point, padding_changes)
154154
return input, scale, zero_point
155155

156-
def post_process(self, input: torch.Tensor) -> torch.Tensor:
156+
def post_process(
157+
self,
158+
input: torch.Tensor,
159+
scale: torch.Tensor,
160+
zero_point: torch.Tensor,
161+
block_size: Tuple[int, ...],
162+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
157163
orig_out_features, orig_in_features = input.shape
158164
in_features = find_multiple(orig_in_features, 1024)
159165
out_features = find_multiple(orig_out_features, 8)
160166
input = torch.nn.functional.pad(
161167
input,
162168
(0, in_features - orig_in_features, 0, out_features - orig_out_features),
163169
)
164-
return input
170+
assert (
171+
len(block_size) == 2
172+
), f"TensorCoreTiledLayout only supports len(block_size) == 2, got: {block_size}"
173+
scale_pad_dim_0 = (out_features - orig_out_features) // block_size[0]
174+
scale_pad_dim_1 = (in_features - orig_in_features) // block_size[1]
175+
scale = torch.nn.functional.pad(scale, (0, scale_pad_dim_1, 0, scale_pad_dim_0))
176+
zero_point = torch.nn.functional.pad(
177+
zero_point, (0, scale_pad_dim_1, 0, scale_pad_dim_0)
178+
)
179+
return input, scale, zero_point
165180

166181
def extra_repr(self):
167182
return f"inner_k_tiles={self.inner_k_tiles}"
@@ -335,31 +350,25 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
335350

336351
if func is aten.slice.Tensor:
337352
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
338-
if dim == 0:
339-
int_data, scale, zero_point = self.get_plain()
340-
int_data = aten.slice.Tensor(int_data, dim, start, end, step)
341-
# this is to handle padding
342-
int_data = self._layout.post_process(int_data)
343-
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
344-
return return_and_correct_aliasing(func, args, kwargs, sliced)
345-
elif dim == 1:
353+
if dim in [0, 1]:
346354
int_data, scale, zero_point = self.get_plain()
347-
assert step == 1, "Only step == 1 is supported in slicing right now"
348355
data_len = int_data.shape[dim]
349356
scale_len = scale.shape[dim]
350357
ratio = data_len / scale_len
351358
start_scale = int(start / ratio)
352359
end_scale = int(end / ratio)
353360

354361
int_data = aten.slice.Tensor(int_data, dim, start, end, step)
355-
# this is to handle padding
356-
int_data = self._layout.post_process(int_data)
357362
scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
358363
zero_point = aten.slice.Tensor(
359364
zero_point, dim, start_scale, end_scale, step
360365
)
366+
# this is to handle padding
367+
int_data, scale, zero_point = self._layout.post_process(
368+
int_data, scale, zero_point, self.block_size
369+
)
361370
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
362-
return sliced
371+
return return_and_correct_aliasing(func, args, kwargs, sliced)
363372
else:
364373
raise NotImplementedError(
365374
f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
@@ -371,6 +380,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
371380

372381
__torch_function__ = torch._C._disabled_torch_function_impl
373382

383+
@property
384+
def block_size(self):
385+
from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
386+
387+
scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
388+
cur_shape = self.shape
389+
assert len(cur_shape) == 4
390+
inner_k_tiles = cur_shape[-1] * 2
391+
original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
392+
groupsize = int(original_shape[1] / scale.shape[-2])
393+
return (1, groupsize)
394+
374395
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
375396
from torchao.quantization.quant_primitives import (
376397
ZeroPointDomain,

torchao/dtypes/uintx/uintx_layout.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,14 @@ class UintxLayout(Layout):
228228
dtype: torch.dtype
229229
pack_dim: int = -1
230230

231-
def post_process(self, input: torch.Tensor) -> torch.Tensor:
232-
return to_uintx(input, self.dtype, self.pack_dim)
231+
def post_process(
232+
self,
233+
input: torch.Tensor,
234+
scale: torch.Tensor,
235+
zero_point: torch.Tensor,
236+
block_size: Tuple[int, ...],
237+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
238+
return to_uintx(input, self.dtype, self.pack_dim), scale, zero_point
233239

234240

235241
@register_layout(UintxLayout)

torchao/dtypes/utils.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,14 @@ class Layout:
4444
def pre_process(self, input: torch.Tensor) -> torch.Tensor:
4545
return input
4646

47-
def post_process(self, input: torch.Tensor) -> torch.Tensor:
48-
return input
47+
def post_process(
48+
self,
49+
input: torch.Tensor,
50+
scale: torch.Tensor,
51+
zero_point: torch.Tensor,
52+
block_size: Tuple[int, ...],
53+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
54+
return input, scale, zero_point
4955

5056
def pre_process_static(
5157
self,

torchao/testing/utils.py

+14
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,20 @@ def wrapper(*args, **kwargs):
9090
return decorator
9191

9292

93+
def skip_if_no_cuda():
94+
import unittest
95+
96+
def decorator(test_func):
97+
def wrapper(*args, **kwargs):
98+
if not torch.cuda.is_available():
99+
raise unittest.SkipTest("No cuda available")
100+
return test_func(*args, **kwargs)
101+
102+
return wrapper
103+
104+
return decorator
105+
106+
93107
# copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389
94108
def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None): # noqa: B902
95109
for name, value in my_cls.__dict__.items():

0 commit comments

Comments
 (0)