@@ -18,6 +18,12 @@
     addmm_float8_unwrapped_inference,
     preprocess_data,
 )
+from torchao.quantization.quant_primitives import (
+    FP8_TYPES,
+    choose_qparams_affine_float8,
+    dequantize_affine_float8,
+    quantize_affine_float8,
+)
 from torchao.utils import _is_float8_type, fill_defaults

 aten = torch.ops.aten
@@ -209,19 +215,64 @@ def __repr__(self):
         )


+class Float8QuantizedTensor(AffineQuantizedTensor):
+    """
+    Float8 quantized tensor subclass which inherits from AffineQuantizedTensor.
+    """
+
+    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if output_dtype is None:
+            output_dtype = self.dtype
+        int_data, scale, _ = self.tensor_impl.get_plain()
+        return dequantize_affine_float8(
+            int_data,
+            scale,
+            output_dtype=output_dtype,
+        )
+
+    @classmethod
+    def from_hp_to_float8(
+        cls,
+        input_float: torch.Tensor,
+        target_dtype: torch.dtype,
+        block_size: Tuple[int, ...],
+        _layout: Layout = Float8Layout(),
+    ):
+        assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
+        original_shape = input_float.shape
+        scale = choose_qparams_affine_float8(
+            input_float,
+            target_dtype,
+        )
+        fp8_data = quantize_affine_float8(
+            input_float,
+            scale,
+            target_dtype,
+        )
+        fp8_data = _layout.post_process(fp8_data)
+        tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
+        tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            dtype=input_float.dtype,
+        )
+
+
 ##########################
 # Float8 Dispatch Kernels
 ##########################


 def _linear_fp8_act_fp8_weight_check(
-    input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
-    weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
+    input_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
+    weight_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
     bias: Optional[torch.Tensor],
 ) -> bool:
-    def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
+    def check_aqt(aqt: Union[torch.Tensor, Float8QuantizedTensor]) -> bool:
         return (
-            isinstance(aqt, AffineQuantizedTensor)
+            isinstance(aqt, Float8QuantizedTensor)
             and isinstance(aqt._layout, Float8Layout)
             and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
             and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt))
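As an aside (my annotation, not part of the diff): the check above gates the fused float8 path, accepting only a `Float8QuantizedTensor` with a `Float8Layout`, an e4m3fn/e5m2 payload, and per-tensor or rowwise scaling. A hedged sketch of constructing a qualifying per-tensor weight with the constructor added above, assuming it runs inside this same module where the class is defined:

```python
import torch

w_hp = torch.randn(128, 64, dtype=torch.bfloat16)
w_fp8 = Float8QuantizedTensor.from_hp_to_float8(
    w_hp,
    target_dtype=torch.float8_e4m3fn,
    block_size=tuple(w_hp.shape),  # per-tensor scaling: shape == block_size
)
# Float8Layout is the default _layout and e4m3fn is in the accepted dtype
# list, so this weight satisfies check_aqt and takes the fused branch.
```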
@@ -241,8 +292,8 @@ def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):


 def _linear_fp8_act_fp8_weight_impl(
-    input_tensor: "AffineQuantizedTensor",
-    weight_tensor: "AffineQuantizedTensor",
+    input_tensor: "Float8QuantizedTensor",
+    weight_tensor: "Float8QuantizedTensor",
     bias: Optional[torch.Tensor],
 ):
     """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
@@ -285,8 +336,8 @@ def _linear_fp8_act_fp8_weight_impl(


 def _linear_fp_act_fp8_weight_check(
-    input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
-    weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
+    input_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
+    weight_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
     bias: Optional[torch.Tensor],
 ) -> bool:
     return (
@@ -295,7 +346,7 @@ def _linear_fp_act_fp8_weight_check(
         and input_tensor.is_floating_point()
         and
         # weight is float8 quantized affine quantized tensor
-        isinstance(weight_tensor, AffineQuantizedTensor)
+        isinstance(weight_tensor, Float8QuantizedTensor)
         and isinstance(weight_tensor._layout, Float8Layout)
         and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
         and (
@@ -307,7 +358,10 @@ def _linear_fp_act_fp8_weight_check(


 def _linear_fp_act_fp8_weight_impl(
     input_tensor: torch.Tensor,
-    weight_tensor: "AffineQuantizedTensor",
+    weight_tensor: "Float8QuantizedTensor",
     bias: Optional[torch.Tensor],
 ):
     return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)
+
+
+to_affine_quantized_float8 = Float8QuantizedTensor.from_hp_to_float8
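For context, a minimal usage sketch of the new `to_affine_quantized_float8` alias (my own illustration, not code from the PR; the `torchao.dtypes` import path is an assumption, since the re-export is not shown in this diff):

```python
import torch

from torchao.dtypes import to_affine_quantized_float8  # assumed export path

weight = torch.randn(256, 512, dtype=torch.bfloat16)
fp8_weight = to_affine_quantized_float8(
    weight,
    torch.float8_e4m3fn,      # target_dtype must be one of FP8_TYPES
    block_size=weight.shape,  # one scale shared by the whole tensor
)
# With no explicit output_dtype, dequantize() returns the original
# high-precision dtype recorded at construction time (bfloat16 here).
recovered = fp8_weight.dequantize()
assert recovered.dtype == torch.bfloat16
```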