Skip to content

Commit 1744c3d

Browse files
authored
Fix hessian criteria statistics collector (#3072)
### Changes
Fix a bug with passing `subset_size` to the Hessian criterion statistic collector.

### Tests
Add a test for the backend methods.
1 parent 6962892 commit 1744c3d

File tree

6 files changed

+188
-23
lines changed

6 files changed

+188
-23
lines changed

nncf/experimental/common/tensor_statistics/collectors.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,21 @@ def _reduce_out_of_place(self, x: List[Tensor]) -> List[Tensor]:
470470
return [fns.mean(x, reduction_axes, keepdims=self._keepdims)]
471471

472472

473+
class MeanVarianceReducer(TensorReducerBase):
    """
    Common reducer type for the "mean variance" statistic.

    No out-of-place reduction is provided at this level: calling
    `_reduce_out_of_place` always raises.
    """

    def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]:
        """
        Out-of-place reduction is not implemented for this reducer.

        :param x: Input tensors (never read).
        :raises NotImplementedError: Always.
        """
        raise NotImplementedError
476+
477+
478+
class MaxVarianceReducer(TensorReducerBase):
    """
    Common reducer type for the "max variance" statistic.

    No out-of-place reduction is provided at this level: calling
    `_reduce_out_of_place` always raises.
    """

    def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]:
        """
        Out-of-place reduction is not implemented for this reducer.

        :param x: Input tensors (never read).
        :raises NotImplementedError: Always.
        """
        raise NotImplementedError
481+
482+
483+
class MeanAbsMaxReducer(TensorReducerBase):
    """
    Common reducer type for the "mean abs max" statistic.

    No out-of-place reduction is provided at this level: calling
    `_reduce_out_of_place` always raises.
    """

    def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]:
        """
        Out-of-place reduction is not implemented for this reducer.

        :param x: Input tensors (never read).
        :raises NotImplementedError: Always.
        """
        raise NotImplementedError
486+
487+
473488
class QuantileReducerBase(TensorReducerBase):
474489
def __init__(
475490
self,

nncf/openvino/statistics/collectors.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,19 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import List, Optional
12+
from typing import Optional
1313

14-
from nncf.common.tensor import TensorType
1514
from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
1615
from nncf.experimental.common.tensor_statistics.collectors import AbsQuantileReducer
1716
from nncf.experimental.common.tensor_statistics.collectors import BatchMeanReducer
1817
from nncf.experimental.common.tensor_statistics.collectors import InplaceInsertionFNType
1918
from nncf.experimental.common.tensor_statistics.collectors import MaxReducer
19+
from nncf.experimental.common.tensor_statistics.collectors import MaxVarianceReducer
20+
from nncf.experimental.common.tensor_statistics.collectors import MeanAbsMaxReducer
2021
from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
2122
from nncf.experimental.common.tensor_statistics.collectors import MeanPerChReducer
2223
from nncf.experimental.common.tensor_statistics.collectors import MeanReducer
24+
from nncf.experimental.common.tensor_statistics.collectors import MeanVarianceReducer
2325
from nncf.experimental.common.tensor_statistics.collectors import MinReducer
2426
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
2527
from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
@@ -28,7 +30,6 @@
2830
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
2931
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
3032
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
31-
from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
3233
from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic
3334
from nncf.experimental.common.tensor_statistics.statistics import RawTensorStatistic
3435
from nncf.openvino.graph.node_utils import get_inplace_batch_mean_op
@@ -67,26 +68,17 @@ def get_inplace_fn(self):
6768
return get_inplace_mean_op(self._reduction_axes)
6869

6970

70-
class OVMeanVarianceReducer(TensorReducerBase):
71-
def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]:
72-
raise NotImplementedError()
73-
71+
class OVMeanVarianceReducer(MeanVarianceReducer):
7472
def get_inplace_fn(self):
7573
return get_inplace_mean_var_op(self._reduction_axes)
7674

7775

78-
class OVMaxVarianceReducer(TensorReducerBase):
79-
def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]:
80-
raise NotImplementedError()
81-
76+
class OVMaxVarianceReducer(MaxVarianceReducer):
8277
def get_inplace_fn(self):
8378
return get_inplace_max_var_op(self._reduction_axes)
8479

8580

86-
class OVMeanAbsMaxReducer(TensorReducerBase):
87-
def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]:
88-
raise NotImplementedError()
89-
81+
class OVMeanAbsMaxReducer(MeanAbsMaxReducer):
9082
def get_inplace_fn(self):
9183
return get_inplace_mean_max_op(self._reduction_axes, True)
9284

nncf/quantization/algorithms/weight_compression/mixed_precision.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def get_statistic_points(
281281
return statistic_container
282282

283283
@abstractmethod
284-
def _get_statistic_collector():
284+
def _get_statistic_collector(self):
285285
"""
286286
Get statistic collector
287287
"""
@@ -360,7 +360,7 @@ def _calc_weight_sensitivity(
360360
return fns.linalg.norm(decompressed_weight - weight, ord="fro").item()
361361

362362
def _get_statistic_collector(self):
363-
return self._backend_entity.hawq_statistic_collector()
363+
return self._backend_entity.hawq_statistic_collector(self._subset_size)
364364

365365

366366
@MIXED_PRECISION_CRITERIA.register(SensitivityMetric.MEAN_ACTIVATION_VARIANCE)
tests/cross_fw/test_templates/test_weights_compression_backends.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) 2024 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from abc import abstractmethod
13+
14+
import pytest
15+
16+
from nncf.experimental.common.tensor_statistics.collectors import HAWQAggregator
17+
from nncf.experimental.common.tensor_statistics.collectors import MaxVarianceReducer
18+
from nncf.experimental.common.tensor_statistics.collectors import MeanAbsMaxReducer
19+
from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
20+
from nncf.experimental.common.tensor_statistics.collectors import MeanVarianceReducer
21+
from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
22+
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
23+
24+
25+
class TemplateTestMixedPrecisionAlgoBackend:
    """
    Template test for the statistic collectors created by mixed precision criteria.

    A backend-specific test class inherits from this template and implements the
    `get_*_with_backend` factories; the shared `test_statistic_collector` then checks
    the collector each criterion returns from `_get_statistic_collector()`.
    """

    @abstractmethod
    def get_hawq_with_backend(self, subset_size: int):
        """Returns a HAWQ instance of the algorithm."""

    @abstractmethod
    def get_mean_variance_with_backend(self, subset_size: int):
        """Returns a Mean Variance instance of the algorithm."""

    @abstractmethod
    def get_max_variance_with_backend(self, subset_size: int):
        """Returns a Max Variance instance of the algorithm."""

    @abstractmethod
    def get_mean_max_with_backend(self, subset_size: int):
        """Returns a Mean Max instance of the algorithm."""

    def check_aggregator(self, collector: TensorCollector, expected_aggregator_type, subset_size: int):
        """
        Checks that the collector holds exactly one aggregator of the expected type
        whose `num_samples` equals the given subset size.

        :param collector: Collector under test.
        :param expected_aggregator_type: Type the single aggregator must be an instance of.
        :param subset_size: Expected value of the aggregator's `num_samples`.
        """
        aggregators = collector.aggregators
        assert len(aggregators) == 1, "Collector should have exactly one aggregator."
        # Read the single aggregator without popping it, so the check has no
        # side effect on the container returned by the collector.
        aggregator = next(iter(aggregators.values()))
        assert isinstance(
            aggregator, expected_aggregator_type
        ), f"Expected aggregator of type {expected_aggregator_type.__name__}, got {type(aggregator).__name__}."
        assert aggregator.num_samples == subset_size, "Aggregator num_samples does not match the provided subset size."

    def check_reducer(self, collector: TensorCollector, expected_reducer_type):
        """
        Checks that the collector holds exactly one reducer of the expected type.

        :param collector: Collector under test.
        :param expected_reducer_type: Type the single reducer must be an instance of.
        """
        reducers = collector.reducers
        assert len(reducers) == 1, "Collector should have exactly one reducer."
        # Non-mutating access to the only element (see check_aggregator).
        reducer = next(iter(reducers))
        assert isinstance(
            reducer, expected_reducer_type
        ), f"Expected reducer of type {expected_reducer_type.__name__}, got {type(reducer).__name__}."

    @pytest.mark.parametrize("subset_size", [1, 10, None])
    @pytest.mark.parametrize(
        "algo_func, aggregator_type, reducer_type",
        [
            ("get_hawq_with_backend", HAWQAggregator, NoopReducer),
            ("get_mean_variance_with_backend", MeanAggregator, MeanVarianceReducer),
            ("get_max_variance_with_backend", MeanAggregator, MaxVarianceReducer),
            ("get_mean_max_with_backend", MeanAggregator, MeanAbsMaxReducer),
        ],
    )
    def test_statistic_collector(self, subset_size, algo_func, aggregator_type, reducer_type):
        """Test function to validate statistic collectors."""
        # The factory is looked up by name so one test body covers every criterion.
        algo = getattr(self, algo_func)(subset_size)
        collector = algo._get_statistic_collector()

        # Verify the collector instance and properties
        assert isinstance(collector, TensorCollector), "Collector is not an instance of TensorCollector."

        # Validate the aggregator and reducer types
        self.check_aggregator(collector, aggregator_type, subset_size)
        self.check_reducer(collector, reducer_type)

tests/openvino/native/quantization/test_weights_compression.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -899,12 +899,54 @@ def test_compression_for_different_dtypes(activation_dtype, weight_dtype):
899899
)
900900
@pytest.mark.parametrize(
901901
("compression_args", "multiplier_of_calls"),
902-
(
903-
(dict(mode=CompressWeightsMode.INT4_ASYM, ratio=1), 0), # data-free, no reducers
904-
(dict(mode=CompressWeightsMode.INT4_ASYM, ratio=0.5), 1), # 1 reducer for mixed precision
905-
(dict(mode=CompressWeightsMode.INT4_ASYM, ratio=1, awq=True), 2), # mean & shape reducer for AWQ
906-
(dict(mode=CompressWeightsMode.INT4_ASYM, ratio=0.5, awq=True), 3), # 2 - for AWQ + 1 - for Mixed Precision
907-
),
902+
[
903+
({"mode": CompressWeightsMode.INT4_ASYM, "ratio": 1}, 0), # data-free, no reducers
904+
({"mode": CompressWeightsMode.INT4_ASYM, "ratio": 1, "awq": True}, 2), # mean & shape reducer for AWQ
905+
(
906+
{"mode": CompressWeightsMode.INT4_ASYM, "ratio": 0.5, "awq": True},
907+
3,
908+
), # 2 - for AWQ + 1 - for Mixed Precision
909+
(
910+
{
911+
"mode": CompressWeightsMode.INT4_ASYM,
912+
"ratio": 0.5,
913+
"sensitivity_metric": nncf.SensitivityMetric.HESSIAN_INPUT_ACTIVATION,
914+
},
915+
1,
916+
), # 1 reducer for mixed precision
917+
(
918+
{
919+
"mode": CompressWeightsMode.INT4_ASYM,
920+
"ratio": 0.5,
921+
"sensitivity_metric": nncf.SensitivityMetric.MEAN_ACTIVATION_VARIANCE,
922+
},
923+
1,
924+
), # 1 reducer for mixed precision
925+
(
926+
{
927+
"mode": CompressWeightsMode.INT4_ASYM,
928+
"ratio": 0.5,
929+
"sensitivity_metric": nncf.SensitivityMetric.MAX_ACTIVATION_VARIANCE,
930+
},
931+
1,
932+
), # 1 reducer for mixed precision
933+
(
934+
{
935+
"mode": CompressWeightsMode.INT4_ASYM,
936+
"ratio": 0.5,
937+
"sensitivity_metric": nncf.SensitivityMetric.MEAN_ACTIVATION_MAGNITUDE,
938+
},
939+
1,
940+
), # 1 reducer for mixed precision
941+
(
942+
{
943+
"mode": CompressWeightsMode.INT4_ASYM,
944+
"ratio": 0.5,
945+
"sensitivity_metric": nncf.SensitivityMetric.WEIGHT_QUANTIZATION_ERROR,
946+
},
947+
0,
948+
), # 0 - data-free method
949+
],
908950
)
909951
def test_number_of_reduced_statistics_for_subset_size(
910952
mocker, dataset_size, subset_size, ref_size, compression_args, multiplier_of_calls
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) 2024 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
from nncf.quantization.algorithms.weight_compression.mixed_precision import HAWQCriterion
12+
from nncf.quantization.algorithms.weight_compression.mixed_precision import MaxVarianceCriterion
13+
from nncf.quantization.algorithms.weight_compression.mixed_precision import MeanMaxCriterion
14+
from nncf.quantization.algorithms.weight_compression.mixed_precision import MeanVarianceCriterion
15+
from nncf.quantization.algorithms.weight_compression.openvino_backend import OVMixedPrecisionAlgoBackend
16+
from tests.cross_fw.test_templates.test_weights_compression_backends import TemplateTestMixedPrecisionAlgoBackend
17+
from tests.openvino.native.models import IdentityMatmul
18+
19+
20+
class TestOVMixedPrecisionAlgoBackend(TemplateTestMixedPrecisionAlgoBackend):
    """
    OpenVINO implementation of the mixed precision backend template test.

    Every factory builds its criterion the same way, so construction is shared
    in `_make_criterion`.
    """

    @staticmethod
    def _make_criterion(criterion_cls, subset_size: int):
        """
        Builds a criterion of the given class bound to the OpenVINO backend.

        :param criterion_cls: Criterion class to instantiate.
        :param subset_size: Value forwarded as the criterion's `subset_size` argument.
        :return: The criterion with `_backend_entity` set to an `OVMixedPrecisionAlgoBackend`
            created from the `IdentityMatmul` test model.
        """
        # The two leading positional constructor arguments are passed as None,
        # exactly as each factory did before the helper was extracted.
        criterion = criterion_cls(None, None, subset_size=subset_size)
        criterion._backend_entity = OVMixedPrecisionAlgoBackend(IdentityMatmul().ov_model)
        return criterion

    def get_hawq_with_backend(self, subset_size: int):
        """Returns a `HAWQCriterion` bound to the OpenVINO backend."""
        return self._make_criterion(HAWQCriterion, subset_size)

    def get_mean_variance_with_backend(self, subset_size: int):
        """Returns a `MeanVarianceCriterion` bound to the OpenVINO backend."""
        return self._make_criterion(MeanVarianceCriterion, subset_size)

    def get_max_variance_with_backend(self, subset_size: int):
        """Returns a `MaxVarianceCriterion` bound to the OpenVINO backend."""
        return self._make_criterion(MaxVarianceCriterion, subset_size)

    def get_mean_max_with_backend(self, subset_size: int):
        """Returns a `MeanMaxCriterion` bound to the OpenVINO backend."""
        return self._make_criterion(MeanMaxCriterion, subset_size)

0 commit comments

Comments
 (0)