[Common] Unified Scales for SDPA (#3205)
### Changes

Include the SDPA metatype in the scales unification map for all MinMax backends.

### Reason for changes

Scales were not being unified for the quantizers inserted around a
subgraph such as:

```
x  y
\  |
 concat   Q   V
     |   /  /
      SDPA  
```
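
For illustration, here is a minimal PyTorch sketch of the pattern above (shapes and tensor names are illustrative, not taken from the commit). Because `x` and `y` are concatenated before reaching SDPA, the quantizers inserted for them must share a single scale; otherwise the two halves of the concatenated tensor would be dequantized with different scales.

```python
import torch
import torch.nn.functional as F


def sdpa_subgraph(x: torch.Tensor, y: torch.Tensor, q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Quantizers placed on x and y must be unified: both feed one concat.
    k = torch.cat((x, y), dim=1)
    # The concat result is consumed by SDPA together with q and v.
    return F.scaled_dot_product_attention(q, k, v)


# Concatenating along the sequence dimension keeps the embedding
# dimension compatible with q and v.
q = torch.randn(2, 10, 8)
v = torch.randn(2, 10, 8)
out = sdpa_subgraph(torch.randn(2, 5, 8), torch.randn(2, 5, 8), q, v)  # (2, 10, 8)
```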

### Tests

A template test was added at
`tests/cross_fw/test_templates/test_unified_scales.py`.
It uses a synthetic SDPA model with a concat operation, calls the
`_get_quantization_target_points` method of the MinMaxQuantization
algorithm to obtain the unified scale groups, and asserts those groups
against the expected ones.
anzr299 authored Jan 22, 2025
1 parent 9c5b459 commit 874f16a
Showing 9 changed files with 148 additions and 4 deletions.
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/onnx_backend.py

```diff
@@ -97,7 +97,7 @@ def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:

     @property
     def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
-        return {om.ONNXConcatMetatype: self.overflow_fix_metatypes}
+        return {om.ONNXConcatMetatype: self.overflow_fix_metatypes + self.scaled_dot_product_attention_metatypes}

     @property
     def hw_config(self) -> HWConfig:
```
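
The same one-line change is applied to each MinMax backend below. Conceptually, the map says: when a concat node's output reaches one of the mapped operation types (the overflow-fix ops, and now SDPA), the quantizers feeding the concat are collected into one unified-scale group. Below is a rough, hypothetical sketch of that lookup (not NNCF's actual traversal code), with strings standing in for `OperatorMetatype` classes:

```python
from typing import Dict, List


def requires_unified_scales(
    node_type: str, consumer_types: List[str], unification_map: Dict[str, List[str]]
) -> bool:
    """Return True if the quantizers feeding `node_type` should share a scale.

    Hypothetical helper: a concat whose output reaches an operation listed
    in the map triggers scale unification for the concat's inputs.
    """
    mapped = unification_map.get(node_type)
    return mapped is not None and any(consumer in mapped for consumer in consumer_types)


# With this commit, SDPA joins the mapped consumer types:
unification_map = {"concat": ["conv", "matmul", "scaled_dot_product_attention"]}
assert requires_unified_scales("concat", ["scaled_dot_product_attention"], unification_map)
```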
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/openvino_backend.py

```diff
@@ -95,7 +95,7 @@ def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:

     @property
     def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
-        return {om.OVConcatMetatype: self.overflow_fix_metatypes}
+        return {om.OVConcatMetatype: self.overflow_fix_metatypes + self.scaled_dot_product_attention_metatypes}

     @property
     def hw_config(self) -> HWConfig:
```
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/torch_backend.py

```diff
@@ -116,7 +116,7 @@ def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:

     @property
     def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
-        return {om.PTCatMetatype: self.overflow_fix_metatypes}
+        return {om.PTCatMetatype: self.overflow_fix_metatypes + self.scaled_dot_product_attention_metatypes}

     @property
     def hw_config(self) -> HWConfig:
```
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/torch_fx_backend.py

```diff
@@ -110,7 +110,7 @@ def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:

     @property
     def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
-        return {om.PTCatMetatype: self.overflow_fix_metatypes}
+        return {om.PTCatMetatype: self.overflow_fix_metatypes + self.scaled_dot_product_attention_metatypes}

     @property
     def hw_config(self) -> HWConfig:
```
51 changes: 51 additions & 0 deletions tests/cross_fw/test_templates/test_unified_scales.py

```python
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import List, TypeVar

import pytest
import torch

from nncf.common.factory import NNCFGraphFactory
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
from tests.torch.test_models.synthetic import ConcatSDPABlock

TModel = TypeVar("TModel")


class TemplateTestUnifiedScales:
    @abstractmethod
    def get_backend_specific_model(self, model: TModel) -> TModel:
        """
        Convert and return the backend-specific model.

        :param model: Model (e.g. a PyTorch model) to be converted.
        :return: Backend-specific model.
        """

    @pytest.mark.parametrize(
        "model_cls, unified_group, unified_group_nncf_network",
        ((ConcatSDPABlock, [["x", "y"]], [["/nncf_model_input_0", "/nncf_model_input_1"]]),),
    )
    def test_unified_groups(
        self, model_cls: TModel, unified_group: List[List[str]], unified_group_nncf_network: List[List[str]]
    ):
        backend_model = self.get_backend_specific_model(model_cls())
        # NNCFNetwork renames the original model inputs, so the expected
        # group uses the /nncf_model_input_* names instead of "x"/"y".
        if isinstance(backend_model, torch.nn.Module) and not isinstance(backend_model, torch.fx.GraphModule):
            unified_group = unified_group_nncf_network

        nncf_graph = NNCFGraphFactory.create(backend_model)
        algo = MinMaxQuantization()
        algo._set_backend_entity(backend_model)
        _, groups = algo._get_quantization_target_points(backend_model, nncf_graph)
        assert [[target.target_node_name for target in group] for group in groups] == unified_group
```
29 changes: 29 additions & 0 deletions tests/openvino/native/test_unified_scales.py

```python
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import openvino as ov
import torch

from tests.cross_fw.test_templates.test_unified_scales import TemplateTestUnifiedScales


class TestUnifiedScales(TemplateTestUnifiedScales):
    def get_backend_specific_model(self, model: torch.nn.Module) -> ov.Model:
        input_shape = model.INPUT_SHAPE
        backend_model = ov.convert_model(
            model,
            example_input=(
                torch.randn(input_shape),
                torch.randn(input_shape),
            ),
        )

        return backend_model
```
30 changes: 30 additions & 0 deletions tests/torch/fx/test_unified_scales.py

```python
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from tests.cross_fw.test_templates.test_unified_scales import TemplateTestUnifiedScales
from tests.torch.fx.helpers import get_torch_fx_model_q_transformed


class TestUnifiedScales(TemplateTestUnifiedScales):
    def get_backend_specific_model(self, model: torch.nn.Module) -> torch.fx.GraphModule:
        input_shape = model.INPUT_SHAPE
        backend_model = get_torch_fx_model_q_transformed(
            model,
            (
                torch.randn(input_shape),
                torch.randn(input_shape),
            ),
        )

        return backend_model
```
18 changes: 18 additions & 0 deletions tests/torch/quantization/test_unified_scales.py

```diff
@@ -27,7 +27,10 @@
 from nncf.common.quantization.structs import NonWeightQuantizerId
 from nncf.torch.dynamic_graph.operation_address import OperationAddress
 from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.model_creation import wrap_model
+from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.quantization.layers import AsymmetricQuantizer
+from tests.cross_fw.test_templates.test_unified_scales import TemplateTestUnifiedScales
 from tests.torch.helpers import create_compressed_model_and_algo_for_test
 from tests.torch.helpers import get_nodes_by_type
 from tests.torch.helpers import register_bn_adaptation_init_args
@@ -711,3 +714,18 @@ def test_unified_scales_with_shared_nodes():

     assert len(compression_ctrl.weight_quantizers) == 1  # The two embedding nodes point to a single shared layer
     assert len(compression_ctrl.non_weight_quantizers) == 0  # The "add" operation has its inputs already quantized
+
+
+class TestUnifiedScales(TemplateTestUnifiedScales):
+    def get_backend_specific_model(self, model: torch.nn.Module) -> NNCFNetwork:
+        input_shape = model.INPUT_SHAPE
+        backend_model = wrap_model(
+            model,
+            (
+                torch.randn(input_shape),
+                torch.randn(input_shape),
+            ),
+            trace_parameters=True,
+        )
+
+        return backend_model
```
16 changes: 16 additions & 0 deletions tests/torch/test_models/synthetic.py

```diff
@@ -662,3 +662,19 @@ def forward(self, x):
         kq /= 2**-2
         kq = torch.softmax(kq, -1)
         return torch.matmul(torch.transpose(kq, 1, 2), v)
+
+
+class ConcatSDPABlock(torch.nn.Module):
+    INPUT_SHAPE = (2, 10, 6)
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        concatenated_input = torch.cat((x, y), dim=-1)
+        query = concatenated_input
+        key = concatenated_input
+        value = concatenated_input
+        attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, dropout_p=0.2)
+
+        return attn_output
```
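
A quick smoke check of the new synthetic block (illustrative usage, not part of the commit; `eval()` disables the dropout inside the SDPA call):

```python
import torch

from tests.torch.test_models.synthetic import ConcatSDPABlock

model = ConcatSDPABlock().eval()
x = torch.randn(ConcatSDPABlock.INPUT_SHAPE)
y = torch.randn(ConcatSDPABlock.INPUT_SHAPE)
with torch.no_grad():
    out = model(x, y)
# Concatenation along the last dimension doubles it: (2, 10, 6) -> (2, 10, 12).
assert out.shape == (2, 10, 12)
```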
