14 | 14 |     MappingType,
15 | 15 |     ZeroPointDomain,
16 | 16 |     choose_qparams_affine,
17 |    | -    choose_qparams_affine_floatx,
18 | 17 |     choose_qparams_and_quantize_affine_hqq,
19 | 18 |     dequantize_affine,
20 |    | -    dequantize_affine_floatx,
21 | 19 |     quantize_affine,
22 |    | -    quantize_affine_floatx,
23 | 20 | )
24 | 21 | from torchao.utils import (
25 | 22 |     TORCH_VERSION_AT_LEAST_2_5,
36 | 33 |     "to_affine_quantized_floatx",
37 | 34 |     "to_affine_quantized_intx_static",
38 | 35 |     "to_affine_quantized_floatx_static",
39 |    | -    "to_affine_quantized_fpx",
40 | 36 | ]
41 | 37 |
42 | 38 |
@@ -126,40 +122,28 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
126 | 122 |         if output_dtype is None:
127 | 123 |             output_dtype = self.dtype
128 | 124 |
129 |     | -        from torchao.dtypes.floatx import FloatxTensorCoreLayout
130 |     | -
131 |     | -        if isinstance(self._layout, FloatxTensorCoreLayout):
132 |     | -            int_data, scale = self.tensor_impl.get_plain()
133 |     | -            return dequantize_affine_floatx(
134 |     | -                int_data,
135 |     | -                scale,
136 |     | -                self._layout.ebits,
137 |     | -                self._layout.mbits,
138 |     | -                output_dtype=output_dtype,
139 |     | -            )
140 |     | -        else:
141 |     | -            data, scale, zero_point = self.tensor_impl.get_plain()
142 |     | -            dq = dequantize_affine(
143 |     | -                data,
144 |     | -                self.block_size,
145 |     | -                scale,
146 |     | -                zero_point,
147 |     | -                data.dtype,
148 |     | -                self.quant_min,
149 |     | -                self.quant_max,
150 |     | -                self.zero_point_domain,
151 |     | -                output_dtype=output_dtype,
152 |     | -            )
153 |     | -            from torchao.dtypes.uintx import TensorCoreTiledLayout
    | 125 | +        data, scale, zero_point = self.tensor_impl.get_plain()
    | 126 | +        dq = dequantize_affine(
    | 127 | +            data,
    | 128 | +            self.block_size,
    | 129 | +            scale,
    | 130 | +            zero_point,
    | 131 | +            data.dtype,
    | 132 | +            self.quant_min,
    | 133 | +            self.quant_max,
    | 134 | +            self.zero_point_domain,
    | 135 | +            output_dtype=output_dtype,
    | 136 | +        )
    | 137 | +        from torchao.dtypes.uintx import TensorCoreTiledLayout
154 | 138 |
155 |     | -            if isinstance(self._layout, TensorCoreTiledLayout):
156 |     | -                # need to return to original shape if tensor was padded
157 |     | -                # in preprocessing
158 |     | -                # TODO: we could add an API for this if there are more use cases
159 |     | -                # (e.g. dequant_post_process) in TensorImpl or Layout
160 |     | -                for dim, dim_size in enumerate(self.shape):
161 |     | -                    dq = dq.narrow(dim, 0, dim_size)
162 |     | -            return dq
    | 139 | +        if isinstance(self._layout, TensorCoreTiledLayout):
    | 140 | +            # need to return to original shape if tensor was padded
    | 141 | +            # in preprocessing
    | 142 | +            # TODO: we could add an API for this if there are more use cases
    | 143 | +            # (e.g. dequant_post_process) in TensorImpl or Layout
    | 144 | +            for dim, dim_size in enumerate(self.shape):
    | 145 | +                dq = dq.narrow(dim, 0, dim_size)
    | 146 | +        return dq
163 | 147 |
164 | 148 |     def __tensor_flatten__(self):
165 | 149 |         return ["tensor_impl"], [
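The simplified dequantize() above now routes everything through the integer affine primitives. A minimal standalone sketch of that round trip, assuming the choose_qparams_affine / quantize_affine / dequantize_affine signatures used in this file (the block_size and dtype choices below are illustrative, not taken from the diff):

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    dequantize_affine,
    quantize_affine,
)

# illustrative input and granularity: one scale/zero_point per row
x = torch.randn(4, 8)
block_size = (1, 8)
scale, zero_point = choose_qparams_affine(
    x, MappingType.ASYMMETRIC, block_size, torch.uint8, quant_min=0, quant_max=255
)
q = quantize_affine(x, block_size, scale, zero_point, torch.uint8, 0, 255)
# same call shape as the new dequantize() body: data, block_size, scale,
# zero_point, input dtype, quant_min, quant_max, then output_dtype
dq = dequantize_affine(
    q, block_size, scale, zero_point, q.dtype, 0, 255, output_dtype=torch.float32
)
assert dq.shape == x.shape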
@@ -395,33 +379,6 @@ def from_hp_to_floatx_static(
395 | 379 |                 f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static"
396 | 380 |             )
397 | 381 |
398 |     | -    @classmethod
399 |     | -    def from_hp_to_fpx(
400 |     | -        cls,
401 |     | -        input_float: torch.Tensor,
402 |     | -        _layout: Layout,
403 |     | -    ):
404 |     | -        from torchao.dtypes.floatx import FloatxTensorCoreLayout
405 |     | -
406 |     | -        assert isinstance(
407 |     | -            _layout, FloatxTensorCoreLayout
408 |     | -        ), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}"
409 |     | -        original_shape = input_float.shape
410 |     | -        input_float = _layout.pre_process(input_float)
411 |     | -        # per axis quantization, where axis = 1
412 |     | -        block_size = list(input_float.shape)
413 |     | -        block_size[1] = 1
414 |     | -
415 |     | -        ebits, mbits = _layout.ebits, _layout.mbits
416 |     | -        # Note: these ops are hardcoded to have per axis quantization (axis=1) right now
417 |     | -        scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
418 |     | -        floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
419 |     | -        floatx_packed = _layout.post_process(floatx_unpacked)
420 |     | -
421 |     | -        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
422 |     | -        tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
423 |     | -        return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)
424 |     | -
425 | 382 |     @property
426 | 383 |     def _layout(self) -> Layout:
427 | 384 |         return self.tensor_impl._layout
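For context, the removed from_hp_to_fpx classmethod wired together the floatx primitives that also disappear from the import block at the top of the file. A minimal sketch of that flow, with the argument order taken from the removed lines; the import location is an assumption here and may change along with this diff:

import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine_floatx,
    dequantize_affine_floatx,
    quantize_affine_floatx,
)

ebits, mbits = 3, 2  # fp6 (e3m2), illustrative choice
w = torch.randn(64, 64)
# per-axis (axis=1) scale, as noted in the removed code
scale = choose_qparams_affine_floatx(w, ebits, mbits)
fpx_unpacked = quantize_affine_floatx(w, scale, ebits, mbits)
w_dq = dequantize_affine_floatx(
    fpx_unpacked, scale, ebits, mbits, output_dtype=torch.float32
)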
@@ -477,8 +434,6 @@ def _apply_fn_to_data(self, fn):
477 | 434 | to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
478 | 435 | to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
479 | 436 | to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
480 |     | -# experimental will be merged in to floatx
481 |     | -to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
482 | 437 |
483 | 438 | if TORCH_VERSION_AT_LEAST_2_5:
484 | 439 |     # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True`