[Common] Unified Scales for SDPA (#3205)
### Changes

Include the SDPA metatype in the scales unification map for all MinMax backends.

### Reason for changes

Scales were not being unified for the quantizers inserted into a subgraph of this shape:

```
x   y
 \  |
concat  Q  V
    |  /  /
     SDPA
```

### Tests

A template test was added at `tests/cross_fw/test_templates/test_unified_scales.py`. It builds a synthetic SDPA model with a concat operation, calls the `_get_quantization_target_points` method of the `MinMaxQuantization` algorithm to obtain the unified scale groups, and asserts that those groups match the expected input node names.
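To make the motivation concrete, here is a small numeric sketch (plain PyTorch, not NNCF internals) of why the two concat inputs need a shared quantization range: with per-input scales the concatenated tensor mixes two different int8 grids, while a unified scale keeps the whole concat output, and hence the SDPA input, on one grid. All tensor shapes and values below are illustrative assumptions.

```python
import torch

# Hypothetical activations feeding the concat; shapes/values are illustrative only.
x = torch.randn(1, 10, 6) * 2.0   # wider dynamic range
y = torch.randn(1, 10, 6) * 0.5   # narrower dynamic range

# Per-input symmetric int8 scales: each input gets its own quantization grid.
scale_x = x.abs().max() / 127.0
scale_y = y.abs().max() / 127.0

# Unified scale: both quantizers share one range, so the concat output
# (and the SDPA input built from it) lies on a single quantization grid.
unified_scale = torch.maximum(x.abs().max(), y.abs().max()) / 127.0

print(f"per-input scales: {scale_x.item():.4f}, {scale_y.item():.4f}")
print(f"unified scale:    {unified_scale.item():.4f}")
```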
Showing 9 changed files with 148 additions and 4 deletions.
New file `tests/cross_fw/test_templates/test_unified_scales.py` (51 lines added):

```python
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import List, TypeVar

import pytest
import torch

from nncf.common.factory import NNCFGraphFactory
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
from tests.torch.test_models.synthetic import ConcatSDPABlock

TModel = TypeVar("TModel")


class TemplateTestUnifiedScales:
    @abstractmethod
    def get_backend_specific_model(self, model: TModel) -> TModel:
        """
        Convert and return a backend-specific model.

        :param model: Model (for example, a PyTorch module) to be converted to the backend-specific representation.
        :return: Backend-specific model.
        """

    @pytest.mark.parametrize(
        "model_cls, unified_group, unified_group_nncf_network",
        ((ConcatSDPABlock, [["x", "y"]], [["/nncf_model_input_0", "/nncf_model_input_1"]]),),
    )
    def test_unified_groups(
        self, model_cls: TModel, unified_group: List[List[str]], unified_group_nncf_network: List[List[str]]
    ):
        backend_model = self.get_backend_specific_model(model_cls())
        # The NNCFNetwork (native PyTorch) backend exposes model inputs under NNCF-specific node names.
        if isinstance(backend_model, torch.nn.Module) and not isinstance(backend_model, torch.fx.GraphModule):
            unified_group = unified_group_nncf_network

        nncf_graph = NNCFGraphFactory.create(backend_model)
        algo = MinMaxQuantization()
        algo._set_backend_entity(backend_model)
        _, groups = algo._get_quantization_target_points(backend_model, nncf_graph)
        assert [[target.target_node_name for target in group] for group in groups] == unified_group
```
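The template test depends on the synthetic `ConcatSDPABlock` model from `tests/torch/test_models/synthetic.py`, which is not shown in this diff. A rough sketch of what such a block could look like is given below; the `INPUT_SHAPE` value and the way the concatenated tensor is wired into SDPA are assumptions, not the actual model definition.

```python
import torch
import torch.nn.functional as F


class ConcatSDPABlockSketch(torch.nn.Module):
    """Hypothetical stand-in for ConcatSDPABlock: two inputs are concatenated and fed into SDPA."""

    INPUT_SHAPE = (1, 2, 8)  # assumed (batch, tokens, embedding) shape

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Quantizers inserted on x and y ahead of this concat are the ones whose scales must be unified.
        concatenated = torch.cat((x, y), dim=1)
        # The remaining SDPA inputs stand in for the Q and V branches of the subgraph above.
        query = torch.ones_like(concatenated)
        value = torch.ones_like(concatenated)
        return F.scaled_dot_product_attention(query, concatenated, value)
```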
OpenVINO backend specialization of the template test (new file, 29 lines added):

```python
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import openvino as ov
import torch

from tests.cross_fw.test_templates.test_unified_scales import TemplateTestUnifiedScales


class TestUnifiedScales(TemplateTestUnifiedScales):
    def get_backend_specific_model(self, model: torch.nn.Module) -> ov.Model:
        input_shape = model.INPUT_SHAPE
        backend_model = ov.convert_model(
            model,
            example_input=(
                torch.randn(input_shape),
                torch.randn(input_shape),
            ),
        )

        return backend_model
```
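For context, here is a hedged sketch of how the unified scales surface at the user level with the OpenVINO backend. This flow is not part of the PR's tests, and the calibration-data format (a tuple of numpy arrays per sample) and the sample count are assumptions.

```python
import numpy as np
import openvino as ov
import torch

import nncf
from tests.torch.test_models.synthetic import ConcatSDPABlock

pt_model = ConcatSDPABlock()
shape = pt_model.INPUT_SHAPE
ov_model = ov.convert_model(
    pt_model,
    example_input=(torch.randn(shape), torch.randn(shape)),
)

# Assumed calibration format: one tuple of float32 arrays per sample, one array per model input.
calibration_items = [
    (np.random.randn(*shape).astype(np.float32), np.random.randn(*shape).astype(np.float32)) for _ in range(8)
]
quantized_model = nncf.quantize(ov_model, nncf.Dataset(calibration_items))
# With this change, the quantizers inserted on both concat inputs are expected to share their scales.
```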
Torch FX backend specialization of the template test (new file, 30 lines added):

```python
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from tests.cross_fw.test_templates.test_unified_scales import TemplateTestUnifiedScales
from tests.torch.fx.helpers import get_torch_fx_model_q_transformed


class TestUnifiedScales(TemplateTestUnifiedScales):
    def get_backend_specific_model(self, model: torch.nn.Module) -> torch.fx.GraphModule:
        input_shape = model.INPUT_SHAPE
        backend_model = get_torch_fx_model_q_transformed(
            model,
            (
                torch.randn(input_shape),
                torch.randn(input_shape),
            ),
        )

        return backend_model
```