From 16d08f2da74b9a2594de721ca41ea228a00cd9be Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 22 Jan 2025 11:07:23 -0800 Subject: [PATCH 01/11] Update [ghstack-poisoned] --- torchao/quantization/quant_primitives.py | 75 +++++++++++++++++++++++- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index e587d4bc2b..1753bd9e22 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -5,10 +5,14 @@ # LICENSE file in the root directory of this source tree. import math -from enum import Enum, auto +from enum import auto, Enum from typing import Callable, Dict, List, Optional, Tuple, Union import torch +from torchao.float8.float8_utils import ( + ScalingGranularity, + tensor_to_scale as tensor_to_float8_scale, +) from torchao.prototype.custom_fp_utils import ( _f32_to_floatx_unpacked, @@ -16,11 +20,11 @@ _n_ones, ) from torchao.utils import ( + _is_float8_type, + _register_custom_op, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, - _is_float8_type, - _register_custom_op, ) __all__ = [ @@ -39,6 +43,9 @@ "MappingType", "ZeroPointDomain", "TorchAODType", + "choose_qparams_affine_float8", + "quantize_affine_float8", + "dequantize_affine_float8", ] @@ -1300,3 +1307,65 @@ def dequantize_affine_floatx( tensor = tensor * scale.float().view(-1, 1) tensor = tensor.to(dtype=output_dtype) return tensor + + +def choose_qparams_affine_float8( + tensor: torch.Tensor, float8_dtype: torch.dtype +) -> torch.Tensor: + """ + Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity. + + Args: + tensor (torch.Tensor): Input tensor to be quantized. + float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2). + """ + # NOTE: quantization primitives are hardcoded to use axiswise granularity w/ axis=1 right now: + # https://github.com/pytorch/ao/blob/5d1444bdef6df15eb89c4c5716ede1c5f8677798/torchao/dtypes/affine_quantized_tensor.py#L416 + scale = tensor_to_float8_scale( + tensor, + float8_dtype, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=1, + ) + return scale + + +def quantize_affine_float8( + tensor: torch.Tensor, + scale: torch.Tensor, + float8_dtype: torch.dtype, +) -> torch.Tensor: + """ + Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor. + + Args: + tensor (torch.Tensor): Input tensor to be quantized. + scale (torch.Tensor): Scaling factor for the quantization. + float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2). + """ + # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically + # upcasted to `float32` to multiply with the scale + # In order to match numerics between eager and compile, we upcast manually here. + tensor_scaled = tensor.to(torch.float32) * scale + max_value = torch.finfo(float8_dtype).max + tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value) + fp8_tensor = tensor_clamped.to(float8_dtype) + return fp8_tensor + + +def dequantize_affine_float8( + tensor: torch.Tensor, + scale: torch.Tensor, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Dequantizes the float8 tensor to float32 tensor. + + Args: + tensor (torch.Tensor): Input float8 tensor to be dequantized. + scale (torch.Tensor): Scaling factor for the dequantization. 
+ output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32). + """ + fp8_tensor = tensor.to(torch.float32) + hp_tensor = fp8_tensor / scale + return hp_tensor.to(output_dtype) From f642f6770fc487e2d99fabe7649d41ab09c94cca Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 22 Jan 2025 11:12:54 -0800 Subject: [PATCH 02/11] Update [ghstack-poisoned] --- torchao/quantization/quant_api.py | 20 +++++++------------- torchao/quantization/quant_primitives.py | 10 ++++++---- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index b2eff196fd..6560915813 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -66,21 +66,13 @@ Int8DynActInt4WeightGPTQQuantizer, Int8DynActInt4WeightQuantizer, ) -from .granularity import ( - PerRow, - PerTensor, -) +from .granularity import PerRow, PerTensor from .linear_activation_quantized_tensor import ( LinearActivationQuantizedTensor, to_linear_activation_quantized, ) -from .qat import ( - intx_quantization_aware_training, -) -from .quant_primitives import ( - MappingType, - ZeroPointDomain, -) +from .qat import intx_quantization_aware_training +from .quant_primitives import MappingType, ZeroPointDomain from .subclass import ( Int4WeightOnlyQuantizedLinearWeight, Int8DynamicallyQuantizedLinearWeight, @@ -915,10 +907,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight quantization + 2:4 sparsity to linear layers. """ - warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead. + warnings.warn( + """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead. from torchao.dtypes import SemiSparseLayout - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""") + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""" + ) return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 1753bd9e22..906dbec73c 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -5,26 +5,28 @@ # LICENSE file in the root directory of this source tree. 
import math -from enum import auto, Enum +from enum import Enum, auto from typing import Callable, Dict, List, Optional, Tuple, Union import torch + from torchao.float8.float8_utils import ( ScalingGranularity, +) +from torchao.float8.float8_utils import ( tensor_to_scale as tensor_to_float8_scale, ) - from torchao.prototype.custom_fp_utils import ( _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, _n_ones, ) from torchao.utils import ( - _is_float8_type, - _register_custom_op, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + _is_float8_type, + _register_custom_op, ) __all__ = [ From 1cbc0375de6a70f82790d86a546587375f268062 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 22 Jan 2025 11:25:42 -0800 Subject: [PATCH 03/11] Update [ghstack-poisoned] --- torchao/quantization/quant_api.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 6560915813..b2eff196fd 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -66,13 +66,21 @@ Int8DynActInt4WeightGPTQQuantizer, Int8DynActInt4WeightQuantizer, ) -from .granularity import PerRow, PerTensor +from .granularity import ( + PerRow, + PerTensor, +) from .linear_activation_quantized_tensor import ( LinearActivationQuantizedTensor, to_linear_activation_quantized, ) -from .qat import intx_quantization_aware_training -from .quant_primitives import MappingType, ZeroPointDomain +from .qat import ( + intx_quantization_aware_training, +) +from .quant_primitives import ( + MappingType, + ZeroPointDomain, +) from .subclass import ( Int4WeightOnlyQuantizedLinearWeight, Int8DynamicallyQuantizedLinearWeight, @@ -907,12 +915,10 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight quantization + 2:4 sparsity to linear layers. """ - warnings.warn( - """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead. + warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead. from torchao.dtypes import SemiSparseLayout - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""" - ) + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""") return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) From 50612528544b58fb01080eeda81df9748ac9fbfc Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 23 Jan 2025 15:15:40 -0800 Subject: [PATCH 04/11] Update [ghstack-poisoned] --- torchao/quantization/quant_primitives.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 906dbec73c..bacc8b2f6c 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -5,15 +5,13 @@ # LICENSE file in the root directory of this source tree. 
import math -from enum import Enum, auto +from enum import auto, Enum from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torchao.float8.float8_utils import ( ScalingGranularity, -) -from torchao.float8.float8_utils import ( tensor_to_scale as tensor_to_float8_scale, ) from torchao.prototype.custom_fp_utils import ( @@ -22,11 +20,11 @@ _n_ones, ) from torchao.utils import ( + _is_float8_type, + _register_custom_op, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, - _is_float8_type, - _register_custom_op, ) __all__ = [ @@ -1346,7 +1344,7 @@ def quantize_affine_float8( float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2). """ # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically - # upcasted to `float32` to multiply with the scale + # upcasted to `float32` to multiply with the scale, since scale is a fp32 tensor in float8 quantization. # In order to match numerics between eager and compile, we upcast manually here. tensor_scaled = tensor.to(torch.float32) * scale max_value = torch.finfo(float8_dtype).max @@ -1361,13 +1359,16 @@ def dequantize_affine_float8( output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ - Dequantizes the float8 tensor to float32 tensor. + Dequantizes the float8 tensor to high precision tensor. Args: tensor (torch.Tensor): Input float8 tensor to be dequantized. scale (torch.Tensor): Scaling factor for the dequantization. output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32). """ + # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically + # upcasted to `float32` to divide by the scale, since scale is a fp32 for float8 quantization. + # In order to match numerics between eager and compile, we upcast manually here. fp8_tensor = tensor.to(torch.float32) hp_tensor = fp8_tensor / scale return hp_tensor.to(output_dtype) From 21afaa1a7c9cfdd3abcd3afa658653cbf25bf5d6 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 23 Jan 2025 15:18:51 -0800 Subject: [PATCH 05/11] Update [ghstack-poisoned] --- torchao/quantization/quant_primitives.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index bacc8b2f6c..fd6acbe994 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -5,13 +5,15 @@ # LICENSE file in the root directory of this source tree. 
import math -from enum import auto, Enum +from enum import Enum, auto from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torchao.float8.float8_utils import ( ScalingGranularity, +) +from torchao.float8.float8_utils import ( tensor_to_scale as tensor_to_float8_scale, ) from torchao.prototype.custom_fp_utils import ( @@ -20,11 +22,11 @@ _n_ones, ) from torchao.utils import ( - _is_float8_type, - _register_custom_op, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + _is_float8_type, + _register_custom_op, ) __all__ = [ From b1521bba20311afe23269ddf7922905f5b2052f3 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 24 Jan 2025 09:26:32 -0800 Subject: [PATCH 06/11] Update [ghstack-poisoned] --- test/quantization/test_quant_primitives.py | 63 ++++++++++++++++++++-- torchao/quantization/quant_primitives.py | 31 ++++++----- 2 files changed, 75 insertions(+), 19 deletions(-) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 102e76cb1a..2c8ff3adf1 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -9,16 +9,21 @@ import unittest import torch +from parameterized import parameterized from torchao.dtypes.utils import is_device +from torchao.float8.float8_utils import EPS as float8_eps from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, choose_qparams_affine, + choose_qparams_affine_float8, dequantize_affine, + dequantize_affine_float8, fake_quantize_affine, fake_quantize_affine_cachemask, + MappingType, quantize_affine, + quantize_affine_float8, + ZeroPointDomain, ) # TODO: remove test for utils? @@ -29,11 +34,11 @@ quantize_activation_per_token_absmax, ) from torchao.utils import ( + is_fbcode, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, - is_fbcode, ) _SEED = 1234 @@ -838,6 +843,58 @@ def test_fake_quantize_affine_cachemask(self): torch.testing.assert_close(dequantized, fake_quantized) torch.testing.assert_close(expected_mask, mask) + @parameterized.expand( + [ + ( + torch.float32, + torch.float8_e4m3fn, + ), + ] + ) + def test_float8_quant_primitives(self, hp_dtype, float8_dtype): + input = torch.randn(10, 10) + + # float8 quantization primitives + scale = choose_qparams_affine_float8(input, float8_dtype) + quantized = quantize_affine_float8(input, scale, float8_dtype) + dequantized = dequantize_affine_float8(quantized, scale, hp_dtype) + + # reference implementation using generic primitives + expected_scale, _ = choose_qparams_affine( + input, + MappingType.SYMMETRIC, + input.shape, + float8_dtype, + eps=float8_eps, # use same EPS as float8 training + scale_dtype=torch.float32, + quant_min=torch.finfo(float8_dtype).min, + quant_max=torch.finfo(float8_dtype).max, + ) + expected_quantized = quantize_affine( + input, + input.shape, + scale, + output_dtype=float8_dtype, + quant_min=torch.finfo(float8_dtype).min, + quant_max=torch.finfo(float8_dtype).max, + zero_point=None, + zero_point_domain=None, + ) + expected_dequantized = dequantize_affine( + expected_quantized, + input.shape, + scale, + input_dtype=float8_dtype, + quant_min=torch.finfo(float8_dtype).min, + quant_max=torch.finfo(float8_dtype).max, + zero_point=None, + zero_point_domain=None, + ) + + self.assertTrue(torch.equal(expected_scale, scale)) + torch.testing.assert_close(expected_quantized, quantized) + torch.testing.assert_close(expected_dequantized, 
dequantized) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index fd6acbe994..1ba22a1ad8 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -5,15 +5,13 @@ # LICENSE file in the root directory of this source tree. import math -from enum import Enum, auto +from enum import auto, Enum from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torchao.float8.float8_utils import ( ScalingGranularity, -) -from torchao.float8.float8_utils import ( tensor_to_scale as tensor_to_float8_scale, ) from torchao.prototype.custom_fp_utils import ( @@ -22,11 +20,11 @@ _n_ones, ) from torchao.utils import ( + _is_float8_type, + _register_custom_op, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, - _is_float8_type, - _register_custom_op, ) __all__ = [ @@ -1321,15 +1319,16 @@ def choose_qparams_affine_float8( tensor (torch.Tensor): Input tensor to be quantized. float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2). """ - # NOTE: quantization primitives are hardcoded to use axiswise granularity w/ axis=1 right now: - # https://github.com/pytorch/ao/blob/5d1444bdef6df15eb89c4c5716ede1c5f8677798/torchao/dtypes/affine_quantized_tensor.py#L416 - scale = tensor_to_float8_scale( - tensor, - float8_dtype, - scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=1, - ) - return scale + # only tensorwise scaling is supported for now: + quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max + min_val_neg = torch.min(tensor) + max_val_pos = torch.max(tensor) + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + return scale.to(dtype=torch.float32) + + # max_val_pos = torch.max(-min_val_neg, max_val_pos) + # scale = max_val_pos / (float(quant_max - quant_min) / 2) def quantize_affine_float8( @@ -1348,7 +1347,7 @@ def quantize_affine_float8( # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically # upcasted to `float32` to multiply with the scale, since scale is a fp32 tensor in float8 quantization. # In order to match numerics between eager and compile, we upcast manually here. - tensor_scaled = tensor.to(torch.float32) * scale + tensor_scaled = tensor.to(torch.float32) / scale max_value = torch.finfo(float8_dtype).max tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value) fp8_tensor = tensor_clamped.to(float8_dtype) @@ -1372,5 +1371,5 @@ def dequantize_affine_float8( # upcasted to `float32` to divide by the scale, since scale is a fp32 for float8 quantization. # In order to match numerics between eager and compile, we upcast manually here. 
fp8_tensor = tensor.to(torch.float32) - hp_tensor = fp8_tensor / scale + hp_tensor = fp8_tensor * scale return hp_tensor.to(output_dtype) From 595581067dffb2810319180d10807a1d8981e388 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 24 Jan 2025 09:27:42 -0800 Subject: [PATCH 07/11] Update [ghstack-poisoned] --- torchao/quantization/quant_primitives.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 1ba22a1ad8..0a96360807 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -5,26 +5,22 @@ # LICENSE file in the root directory of this source tree. import math -from enum import auto, Enum +from enum import Enum, auto from typing import Callable, Dict, List, Optional, Tuple, Union import torch -from torchao.float8.float8_utils import ( - ScalingGranularity, - tensor_to_scale as tensor_to_float8_scale, -) from torchao.prototype.custom_fp_utils import ( _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, _n_ones, ) from torchao.utils import ( - _is_float8_type, - _register_custom_op, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + _is_float8_type, + _register_custom_op, ) __all__ = [ From c1ff230434815f5e7a4784b36991ad9a5801f8d0 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 24 Jan 2025 09:29:42 -0800 Subject: [PATCH 08/11] Update [ghstack-poisoned] --- test/quantization/test_quant_primitives.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 2c8ff3adf1..47e2f3b388 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -14,16 +14,16 @@ from torchao.dtypes.utils import is_device from torchao.float8.float8_utils import EPS as float8_eps from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, choose_qparams_affine, choose_qparams_affine_float8, dequantize_affine, dequantize_affine_float8, fake_quantize_affine, fake_quantize_affine_cachemask, - MappingType, quantize_affine, quantize_affine_float8, - ZeroPointDomain, ) # TODO: remove test for utils? @@ -34,11 +34,11 @@ quantize_activation_per_token_absmax, ) from torchao.utils import ( - is_fbcode, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + is_fbcode, ) _SEED = 1234 From 16535428558e4d11143cf637722ee705e61044a2 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 24 Jan 2025 09:31:15 -0800 Subject: [PATCH 09/11] Update [ghstack-poisoned] --- torchao/quantization/quant_primitives.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 0a96360807..04543aca23 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import math -from enum import Enum, auto +from enum import auto, Enum from typing import Callable, Dict, List, Optional, Tuple, Union import torch @@ -16,11 +16,11 @@ _n_ones, ) from torchao.utils import ( + _is_float8_type, + _register_custom_op, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, - _is_float8_type, - _register_custom_op, ) __all__ = [ @@ -1323,9 +1323,6 @@ def choose_qparams_affine_float8( scale = max_val_pos / (float(quant_max - quant_min) / 2) return scale.to(dtype=torch.float32) - # max_val_pos = torch.max(-min_val_neg, max_val_pos) - # scale = max_val_pos / (float(quant_max - quant_min) / 2) - def quantize_affine_float8( tensor: torch.Tensor, From fd57fdfc5f4a685f4a6b8bf824ced16881b47a7f Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 24 Jan 2025 09:37:48 -0800 Subject: [PATCH 10/11] Update [ghstack-poisoned] --- test/quantization/test_quant_primitives.py | 25 ++++++++++++++++------ torchao/quantization/quant_primitives.py | 4 ++-- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 47e2f3b388..7447f5afc8 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -14,16 +14,16 @@ from torchao.dtypes.utils import is_device from torchao.float8.float8_utils import EPS as float8_eps from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, choose_qparams_affine, choose_qparams_affine_float8, dequantize_affine, dequantize_affine_float8, fake_quantize_affine, fake_quantize_affine_cachemask, + MappingType, quantize_affine, quantize_affine_float8, + ZeroPointDomain, ) # TODO: remove test for utils? @@ -34,11 +34,11 @@ quantize_activation_per_token_absmax, ) from torchao.utils import ( + is_fbcode, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, - is_fbcode, ) _SEED = 1234 @@ -849,15 +849,27 @@ def test_fake_quantize_affine_cachemask(self): torch.float32, torch.float8_e4m3fn, ), + ( + torch.float32, + torch.float8_e5m2, + ), + ( + torch.bfloat16, + torch.float8_e4m3fn, + ), + ( + torch.bfloat16, + torch.float8_e5m2, + ), ] ) def test_float8_quant_primitives(self, hp_dtype, float8_dtype): input = torch.randn(10, 10) # float8 quantization primitives - scale = choose_qparams_affine_float8(input, float8_dtype) - quantized = quantize_affine_float8(input, scale, float8_dtype) - dequantized = dequantize_affine_float8(quantized, scale, hp_dtype) + scale = choose_qparams_affine_float8(input, float8_dtype=float8_dtype) + quantized = quantize_affine_float8(input, scale, float8_dtype=float8_dtype) + dequantized = dequantize_affine_float8(quantized, scale, output_dtype=hp_dtype) # reference implementation using generic primitives expected_scale, _ = choose_qparams_affine( @@ -885,6 +897,7 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype): input.shape, scale, input_dtype=float8_dtype, + output_dtype=hp_dtype, quant_min=torch.finfo(float8_dtype).min, quant_max=torch.finfo(float8_dtype).max, zero_point=None, diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 04543aca23..ea372a3ceb 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -1306,7 +1306,7 @@ def dequantize_affine_floatx( def choose_qparams_affine_float8( - tensor: torch.Tensor, float8_dtype: torch.dtype + tensor: torch.Tensor, 
float8_dtype: torch.dtype = torch.float8_e4m3fn, ) -> torch.Tensor: """ Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity. @@ -1327,7 +1327,7 @@ def choose_qparams_affine_float8( def quantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, - float8_dtype: torch.dtype, + float8_dtype: torch.dtype = torch.float8_e4m3fn, ) -> torch.Tensor: """ Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor. From c6cdbababd5bc5ffadeb24e20d4c6bcd89044344 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 24 Jan 2025 09:38:55 -0800 Subject: [PATCH 11/11] Update [ghstack-poisoned] --- test/quantization/test_quant_primitives.py | 6 +++--- torchao/quantization/quant_primitives.py | 9 +++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 7447f5afc8..77616c1c6a 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -14,16 +14,16 @@ from torchao.dtypes.utils import is_device from torchao.float8.float8_utils import EPS as float8_eps from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, choose_qparams_affine, choose_qparams_affine_float8, dequantize_affine, dequantize_affine_float8, fake_quantize_affine, fake_quantize_affine_cachemask, - MappingType, quantize_affine, quantize_affine_float8, - ZeroPointDomain, ) # TODO: remove test for utils? @@ -34,11 +34,11 @@ quantize_activation_per_token_absmax, ) from torchao.utils import ( - is_fbcode, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + is_fbcode, ) _SEED = 1234 diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index ea372a3ceb..8b0ce28434 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import math -from enum import auto, Enum +from enum import Enum, auto from typing import Callable, Dict, List, Optional, Tuple, Union import torch @@ -16,11 +16,11 @@ _n_ones, ) from torchao.utils import ( - _is_float8_type, - _register_custom_op, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + _is_float8_type, + _register_custom_op, ) __all__ = [ @@ -1306,7 +1306,8 @@ def dequantize_affine_floatx( def choose_qparams_affine_float8( - tensor: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn, + tensor: torch.Tensor, + float8_dtype: torch.dtype = torch.float8_e4m3fn, ) -> torch.Tensor: """ Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
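
A minimal round-trip sketch against the final state of this stack (patch 11/11). It assumes a PyTorch build that ships the float8 dtypes and a torchao checkout with these patches applied; the tensor shape and the error printout are illustrative only, not part of the patch:

import torch

from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(16, 32, dtype=torch.bfloat16)

# Tensorwise scale: amax(x) / ((quant_max - quant_min) / 2), returned as fp32.
scale = choose_qparams_affine_float8(x, float8_dtype=torch.float8_e4m3fn)

# Quantize: upcast to fp32, divide by scale, clamp to the float8 range, cast.
x_fp8 = quantize_affine_float8(x, scale, float8_dtype=torch.float8_e4m3fn)

# Dequantize: upcast to fp32, multiply by scale, cast to the requested dtype.
x_hp = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.bfloat16)

max_err = (x.float() - x_hp.float()).abs().max().item()
print(f"tensorwise scale={scale.item():.6f}, max round-trip error={max_err:.6f}")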
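
The manual fp32 upcasts noted in quantize_affine_float8 and dequantize_affine_float8 exist so that eager and torch.compile produce matching numerics. A hedged spot-check of that claim — assuming torch.compile supports the float8 cast on the backend in use:

import torch

from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(64, 64, dtype=torch.bfloat16)
scale = choose_qparams_affine_float8(x)  # float8_dtype defaults to torch.float8_e4m3fn

eager_q = quantize_affine_float8(x, scale)
compiled_q = torch.compile(quantize_affine_float8)(x, scale)

# Compare via fp32 views to sidestep limited float8 kernel coverage.
print("eager/compile match:", torch.equal(eager_q.float(), compiled_q.float()))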