|
10 | 10 | MappingType,
|
11 | 11 | ZeroPointDomain,
|
12 | 12 | choose_qparams_affine,
|
13 |
| - choose_qparams_affine_float8, |
14 | 13 | choose_qparams_affine_floatx,
|
15 | 14 | choose_qparams_and_quantize_affine_hqq,
|
16 | 15 | dequantize_affine,
|
17 |
| - dequantize_affine_floatx, |
18 | 16 | quantize_affine,
|
19 |
| - quantize_affine_float8, |
20 | 17 | quantize_affine_floatx,
|
21 | 18 | )
|
22 | 19 | from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
|
|
28 | 25 | "AffineQuantizedTensor",
|
29 | 26 | "register_layout",
|
30 | 27 | "to_affine_quantized_intx",
|
31 |
| - "to_affine_quantized_floatx", |
32 | 28 | "to_affine_quantized_intx_static",
|
33 | 29 | "to_affine_quantized_floatx_static",
|
34 | 30 | "to_affine_quantized_fpx",
|
@@ -121,40 +117,28 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
|
121 | 117 | if output_dtype is None:
|
122 | 118 | output_dtype = self.dtype
|
123 | 119 |
|
124 |
| - from torchao.dtypes.floatx import FloatxTensorCoreLayout |
125 |
| - |
126 |
| - if isinstance(self._layout, FloatxTensorCoreLayout): |
127 |
| - int_data, scale = self.tensor_impl.get_plain() |
128 |
| - return dequantize_affine_floatx( |
129 |
| - int_data, |
130 |
| - scale, |
131 |
| - self._layout.ebits, |
132 |
| - self._layout.mbits, |
133 |
| - output_dtype=output_dtype, |
134 |
| - ) |
135 |
| - else: |
136 |
| - data, scale, zero_point = self.tensor_impl.get_plain() |
137 |
| - dq = dequantize_affine( |
138 |
| - data, |
139 |
| - self.block_size, |
140 |
| - scale, |
141 |
| - zero_point, |
142 |
| - data.dtype, |
143 |
| - self.quant_min, |
144 |
| - self.quant_max, |
145 |
| - self.zero_point_domain, |
146 |
| - output_dtype=output_dtype, |
147 |
| - ) |
148 |
| - from torchao.dtypes.uintx import TensorCoreTiledLayout |
| 120 | + data, scale, zero_point = self.tensor_impl.get_plain() |
| 121 | + dq = dequantize_affine( |
| 122 | + data, |
| 123 | + self.block_size, |
| 124 | + scale, |
| 125 | + zero_point, |
| 126 | + data.dtype, |
| 127 | + self.quant_min, |
| 128 | + self.quant_max, |
| 129 | + self.zero_point_domain, |
| 130 | + output_dtype=output_dtype, |
| 131 | + ) |
| 132 | + from torchao.dtypes.uintx import TensorCoreTiledLayout |
149 | 133 |
|
150 |
| - if isinstance(self._layout, TensorCoreTiledLayout): |
151 |
| - # need to return to original shape if tensor was padded |
152 |
| - # in preprocessing |
153 |
| - # TODO: we could add an API for this if there are more use cases |
154 |
| - # (e.g. dequant_post_process) in TensorImpl or Layout |
155 |
| - for dim, dim_size in enumerate(self.shape): |
156 |
| - dq = dq.narrow(dim, 0, dim_size) |
157 |
| - return dq |
| 134 | + if isinstance(self._layout, TensorCoreTiledLayout): |
| 135 | + # need to return to original shape if tensor was padded |
| 136 | + # in preprocessing |
| 137 | + # TODO: we could add an API for this if there are more use cases |
| 138 | + # (e.g. dequant_post_process) in TensorImpl or Layout |
| 139 | + for dim, dim_size in enumerate(self.shape): |
| 140 | + dq = dq.narrow(dim, 0, dim_size) |
| 141 | + return dq |
158 | 142 |
|
159 | 143 | def __tensor_flatten__(self):
|
160 | 144 | return ["tensor_impl"], [
|
@@ -272,7 +256,7 @@ def from_hp_to_intx(
|
272 | 256 | # Note: output will be uint8 tensor for sub byte tensors for now
|
273 | 257 |
|
274 | 258 | data = _layout.post_process(data)
|
275 |
| - tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) |
| 259 | + tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout)) |
276 | 260 | tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
|
277 | 261 | return cls(
|
278 | 262 | tensor_impl,
|
@@ -417,36 +401,6 @@ def from_hp_to_fpx(
|
417 | 401 | tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
|
418 | 402 | return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)
|
419 | 403 |
|
420 |
| - @classmethod |
421 |
| - def from_hp_to_float8( |
422 |
| - cls, |
423 |
| - input_float: torch.Tensor, |
424 |
| - target_dtype: torch.dtype, |
425 |
| - block_size: Tuple[int, ...], |
426 |
| - _layout: Layout = PlainLayout(), |
427 |
| - ): |
428 |
| - assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8" |
429 |
| - original_shape = input_float.shape |
430 |
| - scale = choose_qparams_affine_float8( |
431 |
| - input_float, |
432 |
| - target_dtype, |
433 |
| - target_dtype, |
434 |
| - ) |
435 |
| - fp8_data = quantize_affine_float8( |
436 |
| - input_float, |
437 |
| - scale, |
438 |
| - target_dtype, |
439 |
| - ) |
440 |
| - fp8_data = _layout.post_process(fp8_data) |
441 |
| - tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) |
442 |
| - tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout) |
443 |
| - return cls( |
444 |
| - tensor_impl, |
445 |
| - block_size, |
446 |
| - original_shape, |
447 |
| - dtype=input_float.dtype, |
448 |
| - ) |
449 |
| - |
450 | 404 | @property
|
451 | 405 | def _layout(self) -> Layout:
|
452 | 406 | return self.tensor_impl._layout
|
@@ -500,9 +454,7 @@ def _apply_fn_to_data(self, fn):
|
500 | 454 |
|
501 | 455 | to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
|
502 | 456 | to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
|
503 |
| -to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx |
504 | 457 | to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
|
505 |
| -to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8 |
506 | 458 | # experimental will be merged in to floatx
|
507 | 459 | to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
|
508 | 460 |
|
|
0 commit comments