From 874f16a08eac3abea0c5053517a04a4b0f33eb08 Mon Sep 17 00:00:00 2001 From: Aamir Nazir Date: Wed, 22 Jan 2025 15:56:58 +0400 Subject: [PATCH] [Common] Unified Scales for SDPA (#3205) ### Changes Include SDPA metatype in scales unification map for all MinMax backends ### Reason for changes Scales were not being unified in quantizers inserted for such a subgraph: ``` x y \ | concat Q V | / / SDPA ``` ### Tests A template test was created at `tests/cross_fw/test_templates/test_unified_scales.py`. The template test uses a synthetic SDPA model with a concat operation. It then uses the `_get_quantization_target_points` method of the MinMaxQuantization algorithm to obtain the unified scale groups. These groups are then compared against the expected groups in the assertion. --- .../algorithms/min_max/onnx_backend.py | 2 +- .../algorithms/min_max/openvino_backend.py | 2 +- .../algorithms/min_max/torch_backend.py | 2 +- .../algorithms/min_max/torch_fx_backend.py | 2 +- .../test_templates/test_unified_scales.py | 51 +++++++++++++++++++ tests/openvino/native/test_unified_scales.py | 29 +++++++++++ tests/torch/fx/test_unified_scales.py | 30 +++++++++++ .../torch/quantization/test_unified_scales.py | 18 +++++++ tests/torch/test_models/synthetic.py | 16 ++++++ 9 files changed, 148 insertions(+), 4 deletions(-) create mode 100644 tests/cross_fw/test_templates/test_unified_scales.py create mode 100644 tests/openvino/native/test_unified_scales.py create mode 100644 tests/torch/fx/test_unified_scales.py diff --git a/nncf/quantization/algorithms/min_max/onnx_backend.py b/nncf/quantization/algorithms/min_max/onnx_backend.py index 1f04358cc9c..c6cd199ee7f 100644 --- a/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -97,7 +97,7 @@ def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]: @property def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]: - return {om.ONNXConcatMetatype: self.overflow_fix_metatypes} + 
return {om.ONNXConcatMetatype: self.overflow_fix_metatypes + self.scaled_dot_product_attention_metatypes} @property def hw_config(self) -> HWConfig: diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index fb0525a5e14..ba7d2122e52 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -95,7 +95,7 @@ def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]: @property def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]: - return {om.OVConcatMetatype: self.overflow_fix_metatypes} + return {om.OVConcatMetatype: self.overflow_fix_metatypes + self.scaled_dot_product_attention_metatypes} @property def hw_config(self) -> HWConfig: diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index f0882ffa92c..f74336b07a0 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -116,7 +116,7 @@ def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]: @property def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]: - return {om.PTCatMetatype: self.overflow_fix_metatypes} + return {om.PTCatMetatype: self.overflow_fix_metatypes + self.scaled_dot_product_attention_metatypes} @property def hw_config(self) -> HWConfig: diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py index 9a406f435dd..9336d872f34 100644 --- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -110,7 +110,7 @@ def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]: @property def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]: - return 
{om.PTCatMetatype: self.overflow_fix_metatypes} + return {om.PTCatMetatype: self.overflow_fix_metatypes + self.scaled_dot_product_attention_metatypes} @property def hw_config(self) -> HWConfig: diff --git a/tests/cross_fw/test_templates/test_unified_scales.py b/tests/cross_fw/test_templates/test_unified_scales.py new file mode 100644 index 00000000000..4cb11f3dc6b --- /dev/null +++ b/tests/cross_fw/test_templates/test_unified_scales.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import abstractmethod +from typing import List, TypeVar + +import pytest +import torch + +from nncf.common.factory import NNCFGraphFactory +from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization +from tests.torch.test_models.synthetic import ConcatSDPABlock + +TModel = TypeVar("TModel") + + +class TemplateTestUnifiedScales: + @property + @abstractmethod + def get_backend_specific_model(self, model: TModel) -> TModel: + """ + Convert and return backend specific Model + + :param model: Model (for example in PT) to be converted to backend specific model + :return: Backend specific Model + """ + + @pytest.mark.parametrize( + "model_cls, unified_group, unified_group_nncf_network", + ((ConcatSDPABlock, [["x", "y"]], [["/nncf_model_input_0", "/nncf_model_input_1"]]),), + ) + def test_unified_groups( + self, model_cls: TModel, unified_group: List[List[str]], unified_group_nncf_network: List[List[str]] + ): + backend_model = self.get_backend_specific_model(model_cls()) + if isinstance(backend_model, torch.nn.Module) and not isinstance(backend_model, torch.fx.GraphModule): + unified_group = unified_group_nncf_network + + nncf_graph = NNCFGraphFactory.create(backend_model) + algo = MinMaxQuantization() + algo._set_backend_entity(backend_model) + _, groups = algo._get_quantization_target_points(backend_model, nncf_graph) + assert [[target.target_node_name for target in groups] for groups in groups] == unified_group diff --git a/tests/openvino/native/test_unified_scales.py b/tests/openvino/native/test_unified_scales.py new file mode 100644 index 00000000000..01aa1e75238 --- /dev/null +++ b/tests/openvino/native/test_unified_scales.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import openvino as ov +import torch + +from tests.cross_fw.test_templates.test_unified_scales import TemplateTestUnifiedScales + + +class TestUnifiedScales(TemplateTestUnifiedScales): + def get_backend_specific_model(self, model: torch.nn.Module) -> ov.Model: + input_shape = model.INPUT_SHAPE + backend_model = ov.convert_model( + model, + example_input=( + torch.randn(input_shape), + torch.randn(input_shape), + ), + ) + + return backend_model diff --git a/tests/torch/fx/test_unified_scales.py b/tests/torch/fx/test_unified_scales.py new file mode 100644 index 00000000000..1f3a895cf48 --- /dev/null +++ b/tests/torch/fx/test_unified_scales.py @@ -0,0 +1,30 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + +from nncf.torch.nncf_network import NNCFNetwork +from tests.cross_fw.test_templates.test_unified_scales import TemplateTestUnifiedScales +from tests.torch.fx.helpers import get_torch_fx_model_q_transformed + + +class TestUnifiedScales(TemplateTestUnifiedScales): + def get_backend_specific_model(self, model: torch.nn.Module) -> NNCFNetwork: + input_shape = model.INPUT_SHAPE + backend_model = get_torch_fx_model_q_transformed( + model, + ( + torch.randn(input_shape), + torch.randn(input_shape), + ), + ) + + return backend_model diff --git a/tests/torch/quantization/test_unified_scales.py b/tests/torch/quantization/test_unified_scales.py index b9105b5c4ea..5df0af878d4 100644 --- a/tests/torch/quantization/test_unified_scales.py +++ b/tests/torch/quantization/test_unified_scales.py @@ -27,7 +27,10 @@ from nncf.common.quantization.structs import NonWeightQuantizerId from nncf.torch.dynamic_graph.operation_address import OperationAddress from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.model_creation import wrap_model +from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.quantization.layers import AsymmetricQuantizer +from tests.cross_fw.test_templates.test_unified_scales import TemplateTestUnifiedScales from tests.torch.helpers import create_compressed_model_and_algo_for_test from tests.torch.helpers import get_nodes_by_type from tests.torch.helpers import register_bn_adaptation_init_args @@ -711,3 +714,18 @@ def test_unified_scales_with_shared_nodes(): assert len(compression_ctrl.weight_quantizers) == 1 # The two embedding nodes point to a single shared layer assert len(compression_ctrl.non_weight_quantizers) == 0 # The "add" operation has its inputs already quantized + + +class TestUnifiedScales(TemplateTestUnifiedScales): + def get_backend_specific_model(self, model: torch.nn.Module) -> NNCFNetwork: + input_shape = model.INPUT_SHAPE + backend_model = wrap_model( + model, + ( + torch.randn(input_shape), 
+ torch.randn(input_shape), + ), + trace_parameters=True, + ) + + return backend_model diff --git a/tests/torch/test_models/synthetic.py b/tests/torch/test_models/synthetic.py index ba9c385f41d..096bd1efb51 100644 --- a/tests/torch/test_models/synthetic.py +++ b/tests/torch/test_models/synthetic.py @@ -662,3 +662,19 @@ def forward(self, x): kq /= 2**-2 kq = torch.softmax(kq, -1) return torch.matmul(torch.transpose(kq, 1, 2), v) + + +class ConcatSDPABlock(torch.nn.Module): + INPUT_SHAPE = (2, 10, 6) + + def __init__(self): + super().__init__() + + def forward(self, x, y): + concatenated_input = torch.cat((x, y), dim=-1) + query = concatenated_input + key = concatenated_input + value = concatenated_input + attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, dropout_p=0.2) + + return attn_output