@@ -153,15 +153,30 @@ def pre_process_static(
153
153
zero_point = torch .nn .functional .pad (zero_point , padding_changes )
154
154
return input , scale , zero_point
155
155
156
- def post_process (self , input : torch .Tensor ) -> torch .Tensor :
156
+ def post_process (
157
+ self ,
158
+ input : torch .Tensor ,
159
+ scale : torch .Tensor ,
160
+ zero_point : torch .Tensor ,
161
+ block_size : Tuple [int , ...],
162
+ ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
157
163
orig_out_features , orig_in_features = input .shape
158
164
in_features = find_multiple (orig_in_features , 1024 )
159
165
out_features = find_multiple (orig_out_features , 8 )
160
166
input = torch .nn .functional .pad (
161
167
input ,
162
168
(0 , in_features - orig_in_features , 0 , out_features - orig_out_features ),
163
169
)
164
- return input
170
+ assert (
171
+ len (block_size ) == 2
172
+ ), f"TensorCoreTiledLayout only supports len(block_size) == 2, got: { block_size } "
173
+ scale_pad_dim_0 = (out_features - orig_out_features ) // block_size [0 ]
174
+ scale_pad_dim_1 = (in_features - orig_in_features ) // block_size [1 ]
175
+ scale = torch .nn .functional .pad (scale , (0 , scale_pad_dim_1 , 0 , scale_pad_dim_0 ))
176
+ zero_point = torch .nn .functional .pad (
177
+ zero_point , (0 , scale_pad_dim_1 , 0 , scale_pad_dim_0 )
178
+ )
179
+ return input , scale , zero_point
165
180
166
181
def extra_repr (self ):
167
182
return f"inner_k_tiles={ self .inner_k_tiles } "
@@ -335,31 +350,25 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
335
350
336
351
if func is aten .slice .Tensor :
337
352
self , dim , start , end , step = fill_defaults (args , 5 , [0 , None , None , 1 ])
338
- if dim == 0 :
339
- int_data , scale , zero_point = self .get_plain ()
340
- int_data = aten .slice .Tensor (int_data , dim , start , end , step )
341
- # this is to handle padding
342
- int_data = self ._layout .post_process (int_data )
343
- sliced = self .from_plain (int_data , scale , zero_point , self ._layout )
344
- return return_and_correct_aliasing (func , args , kwargs , sliced )
345
- elif dim == 1 :
353
+ if dim in [0 , 1 ]:
346
354
int_data , scale , zero_point = self .get_plain ()
347
- assert step == 1 , "Only step == 1 is supported in slicing right now"
348
355
data_len = int_data .shape [dim ]
349
356
scale_len = scale .shape [dim ]
350
357
ratio = data_len / scale_len
351
358
start_scale = int (start / ratio )
352
359
end_scale = int (end / ratio )
353
360
354
361
int_data = aten .slice .Tensor (int_data , dim , start , end , step )
355
- # this is to handle padding
356
- int_data = self ._layout .post_process (int_data )
357
362
scale = aten .slice .Tensor (scale , dim , start_scale , end_scale , step )
358
363
zero_point = aten .slice .Tensor (
359
364
zero_point , dim , start_scale , end_scale , step
360
365
)
366
+ # this is to handle padding
367
+ int_data , scale , zero_point = self ._layout .post_process (
368
+ int_data , scale , zero_point , self .block_size
369
+ )
361
370
sliced = self .from_plain (int_data , scale , zero_point , self ._layout )
362
- return sliced
371
+ return return_and_correct_aliasing ( func , args , kwargs , sliced )
363
372
else :
364
373
raise NotImplementedError (
365
374
f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run { func } , with dim={ dim } , that is not supported"
@@ -371,6 +380,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
371
380
372
381
__torch_function__ = torch ._C ._disabled_torch_function_impl
373
382
383
+ @property
384
+ def block_size (self ):
385
+ from torchao .quantization .utils import unpack_tinygemm_scales_and_zeros
386
+
387
+ scale , zero = unpack_tinygemm_scales_and_zeros (self .scale_and_zero )
388
+ cur_shape = self .shape
389
+ assert len (cur_shape ) == 4
390
+ inner_k_tiles = cur_shape [- 1 ] * 2
391
+ original_shape = (cur_shape [0 ] * 8 , cur_shape [1 ] * (inner_k_tiles * 16 ))
392
+ groupsize = int (original_shape [1 ] / scale .shape [- 2 ])
393
+ return (1 , groupsize )
394
+
374
395
def get_plain (self ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
375
396
from torchao .quantization .quant_primitives import (
376
397
ZeroPointDomain ,
0 commit comments