Skip to content

Commit f355847

Browse files
[mypy] hwconfig (#3189)
### Changes Enable mypy check for `nncf/common/hardware/config.py`
1 parent 2a5ee2a commit f355847

File tree

4 files changed

+25
-22
lines changed

4 files changed

+25
-22
lines changed

nncf/common/hardware/config.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from pathlib import Path
1616
from typing import Any, Dict, List, Optional, Set, Type
1717

18-
import jstyleson as json
18+
import jstyleson as json # type: ignore[import-untyped]
1919

2020
import nncf
2121
from nncf.common.graph.operator_metatypes import OperatorMetatype
@@ -60,7 +60,7 @@ def get_hw_config_type(target_device: str) -> Optional[HWConfigType]:
6060
return HWConfigType(HW_CONFIG_TYPE_TARGET_DEVICE_MAP[target_device])
6161

6262

63-
class HWConfig(list, ABC):
63+
class HWConfig(list[Dict[str, Any]], ABC):
6464
QUANTIZATION_ALGORITHM_NAME = "quantization"
6565
ATTRIBUTES_NAME = "attributes"
6666
SCALE_ATTRIBUTE_NAME = "scales"
@@ -69,23 +69,23 @@ class HWConfig(list, ABC):
6969

7070
TYPE_TO_CONF_NAME_DICT = {HWConfigType.CPU: "cpu.json", HWConfigType.NPU: "npu.json", HWConfigType.GPU: "gpu.json"}
7171

72-
def __init__(self):
72+
def __init__(self) -> None:
7373
super().__init__()
74-
self.registered_algorithm_configs = {}
74+
self.registered_algorithm_configs: Dict[str, Any] = {}
7575
self.target_device = None
7676

7777
@abstractmethod
7878
def _get_available_operator_metatypes_for_matching(self) -> List[Type[OperatorMetatype]]:
7979
pass
8080

8181
@staticmethod
82-
def get_path_to_hw_config(hw_config_type: HWConfigType):
82+
def get_path_to_hw_config(hw_config_type: HWConfigType) -> str:
8383
return "/".join(
8484
[NNCF_PACKAGE_ROOT_DIR, HW_CONFIG_RELATIVE_DIR, HWConfig.TYPE_TO_CONF_NAME_DICT[hw_config_type]]
8585
)
8686

8787
@classmethod
88-
def from_dict(cls, dct: dict):
88+
def from_dict(cls, dct: Dict[str, Any]) -> "HWConfig":
8989
hw_config = cls()
9090
hw_config.target_device = dct["target_device"]
9191

@@ -104,7 +104,7 @@ def from_dict(cls, dct: dict):
104104
for algorithm_name in op_dict:
105105
if algorithm_name not in hw_config.registered_algorithm_configs:
106106
continue
107-
tmp_config = {}
107+
tmp_config: Dict[str, List[Dict[str, Any]]] = {}
108108
for algo_and_op_specific_field_name, algorithm_configs in op_dict[algorithm_name].items():
109109
if not isinstance(algorithm_configs, list):
110110
algorithm_configs = [algorithm_configs]
@@ -129,30 +129,30 @@ def from_dict(cls, dct: dict):
129129
return hw_config
130130

131131
@classmethod
132-
def from_json(cls, path):
132+
def from_json(cls: type["HWConfig"], path: str) -> List[Dict[str, Any]]:
133133
file_path = Path(path).resolve()
134134
with safe_open(file_path) as f:
135135
json_config = json.load(f, object_pairs_hook=OrderedDict)
136136
return cls.from_dict(json_config)
137137

138138
@staticmethod
139-
def get_quantization_mode_from_config_value(str_val: str):
139+
def get_quantization_mode_from_config_value(str_val: str) -> str:
140140
if str_val == "symmetric":
141141
return QuantizationMode.SYMMETRIC
142142
if str_val == "asymmetric":
143143
return QuantizationMode.ASYMMETRIC
144144
raise nncf.ValidationError("Invalid quantization type specified in HW config")
145145

146146
@staticmethod
147-
def get_is_per_channel_from_config_value(str_val: str):
147+
def get_is_per_channel_from_config_value(str_val: str) -> bool:
148148
if str_val == "perchannel":
149149
return True
150150
if str_val == "pertensor":
151151
return False
152152
raise nncf.ValidationError("Invalid quantization granularity specified in HW config")
153153

154154
@staticmethod
155-
def get_qconf_from_hw_config_subdict(quantization_subdict: Dict):
155+
def get_qconf_from_hw_config_subdict(quantization_subdict: Dict[str, Any]) -> QuantizerConfig:
156156
bits = quantization_subdict["bits"]
157157
mode = HWConfig.get_quantization_mode_from_config_value(quantization_subdict["mode"])
158158
is_per_channel = HWConfig.get_is_per_channel_from_config_value(quantization_subdict["granularity"])
@@ -181,20 +181,22 @@ def get_qconf_from_hw_config_subdict(quantization_subdict: Dict):
181181
)
182182

183183
@staticmethod
184-
def is_qconf_list_corresponding_to_unspecified_op(qconf_list: Optional[List[QuantizerConfig]]):
184+
def is_qconf_list_corresponding_to_unspecified_op(qconf_list: Optional[List[QuantizerConfig]]) -> bool:
185185
return qconf_list is None
186186

187187
@staticmethod
188-
def is_wildcard_quantization(qconf_list: Optional[List[QuantizerConfig]]):
188+
def is_wildcard_quantization(qconf_list: Optional[List[QuantizerConfig]]) -> bool:
189189
# Corresponds to an op itself being specified in the HW config, but having no associated quantization
190190
# configs specified
191191
return qconf_list is not None and len(qconf_list) == 0
192192

193193
def get_metatype_vs_quantizer_configs_map(
194-
self, for_weights=False
194+
self, for_weights: bool = False
195195
) -> Dict[Type[OperatorMetatype], Optional[List[QuantizerConfig]]]:
196196
# 'None' for ops unspecified in HW config, empty list for wildcard quantization ops
197-
retval = {k: None for k in self._get_available_operator_metatypes_for_matching()}
197+
retval: Dict[Type[OperatorMetatype], Optional[List[QuantizerConfig]]] = {
198+
k: None for k in self._get_available_operator_metatypes_for_matching()
199+
}
198200
config_key = "weights" if for_weights else "activations"
199201
for op_dict in self:
200202
hw_config_op_name = op_dict["type"]

nncf/common/quantization/structs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from copy import deepcopy
1313
from enum import Enum
14-
from typing import Any, Dict, List, Optional
14+
from typing import Any, Dict, List, Optional, Union
1515

1616
import nncf
1717
from nncf.common.graph import NNCFNode
@@ -45,7 +45,7 @@ class QuantizerConfig:
4545
def __init__(
4646
self,
4747
num_bits: int = QUANTIZATION_BITS,
48-
mode: QuantizationScheme = QuantizationScheme.SYMMETRIC,
48+
mode: Union[QuantizationScheme, str] = QuantizationScheme.SYMMETRIC, # TODO(AlexanderDokuchaev): use enum
4949
signedness_to_force: Optional[bool] = None,
5050
per_channel: bool = QUANTIZATION_PER_CHANNEL,
5151
):

nncf/common/utils/helpers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
import itertools
1313
import os
1414
import os.path as osp
15-
import pathlib
16-
from typing import Any, Dict, Hashable, Iterable, List, Optional, Union
15+
from pathlib import Path
16+
from typing import Any, Dict, Hashable, Iterable, List, Optional, TypeVar, Union
1717

1818
from tabulate import tabulate
1919

2020
from nncf.common.utils.os import is_windows
2121

22+
TKey = TypeVar("TKey", bound=Hashable)
23+
2224

2325
def create_table(
2426
header: List[str],
@@ -44,7 +46,7 @@ def create_table(
4446
return tabulate(tabular_data=rows, headers=header, tablefmt=table_fmt, maxcolwidths=max_col_widths, floatfmt=".3f")
4547

4648

47-
def configure_accuracy_aware_paths(log_dir: Union[str, pathlib.Path]) -> Union[str, pathlib.Path]:
49+
def configure_accuracy_aware_paths(log_dir: Union[str, Path]) -> Union[str, Path]:
4850
"""
4951
Create a subdirectory inside of the passed log directory
5052
to save checkpoints from the accuracy-aware training loop to.
@@ -59,7 +61,7 @@ def configure_accuracy_aware_paths(log_dir: Union[str, pathlib.Path]) -> Union[s
5961
return acc_aware_log_dir
6062

6163

62-
def product_dict(d: Dict[Hashable, List[str]]) -> Iterable[Dict[Hashable, str]]:
64+
def product_dict(d: Dict[TKey, List[Any]]) -> Iterable[Dict[TKey, Any]]:
6365
"""
6466
Generates dicts which enumerate the options for keys given in the input dict;
6567
options are represented by list values in the input dict.

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ exclude = [
109109
"nncf/common/composite_compression.py",
110110
"nncf/common/compression.py",
111111
"nncf/common/deprecation.py",
112-
"nncf/common/hardware/config.py",
113112
"nncf/common/logging/progress_bar.py",
114113
"nncf/common/logging/track_progress.py",
115114
"nncf/common/pruning/clusterization.py",

0 commit comments

Comments
 (0)