|
14 | 14 | MappingType, |
15 | 15 | ZeroPointDomain, |
16 | 16 | choose_qparams_affine, |
17 | | - choose_qparams_affine_floatx, |
18 | 17 | choose_qparams_and_quantize_affine_hqq, |
19 | 18 | dequantize_affine, |
20 | | - dequantize_affine_floatx, |
21 | 19 | quantize_affine, |
22 | | - quantize_affine_floatx, |
23 | 20 | ) |
24 | 21 | from torchao.utils import ( |
25 | 22 | TORCH_VERSION_AT_LEAST_2_5, |
|
36 | 33 | "to_affine_quantized_floatx", |
37 | 34 | "to_affine_quantized_intx_static", |
38 | 35 | "to_affine_quantized_floatx_static", |
39 | | - "to_affine_quantized_fpx", |
40 | 36 | ] |
41 | 37 |
|
42 | 38 |
|
@@ -126,40 +122,28 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor |
126 | 122 | if output_dtype is None: |
127 | 123 | output_dtype = self.dtype |
128 | 124 |
|
129 | | - from torchao.dtypes.floatx import FloatxTensorCoreLayout |
130 | | - |
131 | | - if isinstance(self._layout, FloatxTensorCoreLayout): |
132 | | - int_data, scale = self.tensor_impl.get_plain() |
133 | | - return dequantize_affine_floatx( |
134 | | - int_data, |
135 | | - scale, |
136 | | - self._layout.ebits, |
137 | | - self._layout.mbits, |
138 | | - output_dtype=output_dtype, |
139 | | - ) |
140 | | - else: |
141 | | - data, scale, zero_point = self.tensor_impl.get_plain() |
142 | | - dq = dequantize_affine( |
143 | | - data, |
144 | | - self.block_size, |
145 | | - scale, |
146 | | - zero_point, |
147 | | - data.dtype, |
148 | | - self.quant_min, |
149 | | - self.quant_max, |
150 | | - self.zero_point_domain, |
151 | | - output_dtype=output_dtype, |
152 | | - ) |
153 | | - from torchao.dtypes.uintx import TensorCoreTiledLayout |
| 125 | + data, scale, zero_point = self.tensor_impl.get_plain() |
| 126 | + dq = dequantize_affine( |
| 127 | + data, |
| 128 | + self.block_size, |
| 129 | + scale, |
| 130 | + zero_point, |
| 131 | + data.dtype, |
| 132 | + self.quant_min, |
| 133 | + self.quant_max, |
| 134 | + self.zero_point_domain, |
| 135 | + output_dtype=output_dtype, |
| 136 | + ) |
| 137 | + from torchao.dtypes.uintx import TensorCoreTiledLayout |
154 | 138 |
|
155 | | - if isinstance(self._layout, TensorCoreTiledLayout): |
156 | | - # need to return to original shape if tensor was padded |
157 | | - # in preprocessing |
158 | | - # TODO: we could add an API for this if there are more use cases |
159 | | - # (e.g. dequant_post_process) in TensorImpl or Layout |
160 | | - for dim, dim_size in enumerate(self.shape): |
161 | | - dq = dq.narrow(dim, 0, dim_size) |
162 | | - return dq |
| 139 | + if isinstance(self._layout, TensorCoreTiledLayout): |
| 140 | + # need to return to original shape if tensor was padded |
| 141 | + # in preprocessing |
| 142 | + # TODO: we could add an API for this if there are more use cases |
| 143 | + # (e.g. dequant_post_process) in TensorImpl or Layout |
| 144 | + for dim, dim_size in enumerate(self.shape): |
| 145 | + dq = dq.narrow(dim, 0, dim_size) |
| 146 | + return dq |
163 | 147 |
|
164 | 148 | def __tensor_flatten__(self): |
165 | 149 | return ["tensor_impl"], [ |
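
Note on the simplified path: the `dequantize_affine` call kept above reconstructs high-precision values block by block. Below is a minimal sketch of that computation for a 2D tensor in the integer zero-point domain, assuming `block_size = (1, group_size)` (per-row groups sharing one scale/zero_point pair). This is illustrative only, not torchao's implementation, and it omits quant_min/quant_max handling and the FLOAT zero-point domain.

```python
import torch

def dequantize_affine_sketch(data, block_size, scale, zero_point, output_dtype=torch.float32):
    rows, cols = data.shape
    group_size = block_size[1]
    num_groups = cols // group_size
    # view each row as groups of `group_size` values, so every group
    # broadcasts against its own (scale, zero_point) pair
    grouped = data.reshape(rows, num_groups, group_size).to(output_dtype)
    scale = scale.reshape(rows, num_groups, 1).to(output_dtype)
    zero_point = zero_point.reshape(rows, num_groups, 1).to(output_dtype)
    dq = (grouped - zero_point) * scale  # integer zero-point domain: dq = (q - zp) * s
    return dq.reshape(rows, cols)

# toy usage: int8 data, one (scale, zero_point) per group of 4 columns
data = torch.randint(-128, 128, (2, 8), dtype=torch.int8)
scale = torch.rand(2, 2) + 0.01
zero_point = torch.zeros(2, 2)
out = dequantize_affine_sketch(data, (1, 4), scale, zero_point)
```
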
@@ -395,33 +379,6 @@ def from_hp_to_floatx_static( |
395 | 379 | f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static" |
396 | 380 | ) |
397 | 381 |
|
398 | | - @classmethod |
399 | | - def from_hp_to_fpx( |
400 | | - cls, |
401 | | - input_float: torch.Tensor, |
402 | | - _layout: Layout, |
403 | | - ): |
404 | | - from torchao.dtypes.floatx import FloatxTensorCoreLayout |
405 | | - |
406 | | - assert isinstance( |
407 | | - _layout, FloatxTensorCoreLayout |
408 | | - ), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" |
409 | | - original_shape = input_float.shape |
410 | | - input_float = _layout.pre_process(input_float) |
411 | | - # per axis quantization, where axis = 1 |
412 | | - block_size = list(input_float.shape) |
413 | | - block_size[1] = 1 |
414 | | - |
415 | | - ebits, mbits = _layout.ebits, _layout.mbits |
416 | | - # Note: these ops are hardcoded to have per axis quantization (axis=1) right now |
417 | | - scale = choose_qparams_affine_floatx(input_float, ebits, mbits) |
418 | | - floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) |
419 | | - floatx_packed = _layout.post_process(floatx_unpacked) |
420 | | - |
421 | | - tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) |
422 | | - tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout) |
423 | | - return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype) |
424 | | - |
425 | 382 | @property |
426 | 383 | def _layout(self) -> Layout: |
427 | 384 | return self.tensor_impl._layout |
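
The removed `from_hp_to_fpx` constructor above chose one scale per axis before packing into the floatx layout. As a generic illustration of that style of per-axis symmetric scaling (this is not the removed `choose_qparams_affine_floatx` / `quantize_affine_floatx` implementation; the axis convention and the 28.0 maximum used here are placeholder assumptions), a minimal sketch:

```python
import torch

def choose_symmetric_scale(x, axis, max_representable):
    # one scale per index along `axis`, chosen so that slice's absolute
    # maximum maps to the largest representable magnitude of the target format
    reduce_dims = [d for d in range(x.dim()) if d != axis]
    absmax = x.abs().amax(dim=reduce_dims, keepdim=True)
    return absmax.clamp(min=1e-12) / max_representable

w = torch.randn(64, 128)
scale = choose_symmetric_scale(w, axis=0, max_representable=28.0)  # 28.0 ~ fp6 e3m2 max, as an example
w_q = w / scale    # scaled into the target format's range (rounding/packing omitted)
w_dq = w_q * scale # dequantize; matches w up to the omitted rounding error
```
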
@@ -477,8 +434,6 @@ def _apply_fn_to_data(self, fn): |
477 | 434 | to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static |
478 | 435 | to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx |
479 | 436 | to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static |
480 | | -# experimental will be merged in to floatx |
481 | | -to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx |
482 | 437 |
|
483 | 438 | if TORCH_VERSION_AT_LEAST_2_5: |
484 | 439 | # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` |
|