Skip to content

Commit fec8c6e

Browse files
[mypy] nncf/config (#3175)
### Changes Enable mypy check for `nncf/config`
1 parent 6a05971 commit fec8c6e

File tree

9 files changed

+42
-35
lines changed

9 files changed

+42
-35
lines changed

nncf/common/accuracy_aware_training/training_loop.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def __init__(
277277
super().__init__(compression_controller)
278278
accuracy_aware_training_params = extract_accuracy_aware_training_params(nncf_config)
279279
runner_factory = EarlyExitTrainingRunnerCreator(
280-
accuracy_aware_training_params, # type: ignore
280+
accuracy_aware_training_params,
281281
compression_controller,
282282
uncompressed_model_accuracy,
283283
verbose,
@@ -330,7 +330,7 @@ def __init__(
330330

331331
accuracy_aware_training_params = extract_accuracy_aware_training_params(nncf_config)
332332
runner_factory = AdaptiveCompressionLevelTrainingRunnerCreator(
333-
accuracy_aware_training_params, # type: ignore
333+
accuracy_aware_training_params,
334334
self.adaptive_controller,
335335
uncompressed_model_accuracy,
336336
verbose,

nncf/config/config.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111

1212
from copy import deepcopy
1313
from pathlib import Path
14-
from typing import Dict, List, Optional, Type
14+
from typing import Any, Dict, List, Optional, Type
1515

1616
import jsonschema
17-
import jstyleson as json
17+
import jstyleson as json # type: ignore
1818

1919
import nncf
2020
from nncf.common.logging import nncf_logger
@@ -28,27 +28,27 @@
2828

2929

3030
@api(canonical_alias="nncf.NNCFConfig")
31-
class NNCFConfig(dict):
31+
class NNCFConfig(dict[str, Any]):
3232
"""Contains the configuration parameters required for NNCF to apply the selected algorithms.
3333
3434
This is a regular dictionary object extended with some utility functions, such as the ability to attach well-defined
3535
structures to pass non-serializable objects as parameters. It is primarily built from a .json file, or from a
3636
Python JSON-like dictionary - both data types will be checked against a JSONSchema. See the definition of the
3737
schema at https://openvinotoolkit.github.io/nncf/schema/, or by calling NNCFConfig.schema()."""
3838

39-
def __init__(self, *args, **kwargs):
39+
def __init__(self, *args: Any, **kwargs: Any) -> None:
4040
super().__init__(*args, **kwargs)
4141
self.__nncf_extra_structs: Dict[str, NNCFExtraConfigStruct] = {}
4242

4343
@classmethod
44-
def from_dict(cls, nncf_dict: Dict) -> "NNCFConfig":
44+
def from_dict(cls, nncf_dict: Dict[str, Any]) -> "NNCFConfig":
4545
"""
4646
Load NNCF config from a Python dictionary. The dict must contain only JSON-supported primitives.
4747
4848
:param nncf_dict: A Python dict with the JSON-style configuration for NNCF.
4949
"""
5050

51-
NNCFConfig.validate(nncf_dict)
51+
cls.validate(nncf_dict)
5252
return cls(deepcopy(nncf_dict))
5353

5454
@classmethod
@@ -63,7 +63,7 @@ def from_json(cls, path: str) -> "NNCFConfig":
6363
loaded_json = json.load(f)
6464
return cls.from_dict(loaded_json)
6565

66-
def register_extra_structs(self, struct_list: List[NNCFExtraConfigStruct]):
66+
def register_extra_structs(self, struct_list: List[NNCFExtraConfigStruct]) -> None:
6767
"""
6868
Attach the supplied list of extra configuration structures to this configuration object.
6969
@@ -78,7 +78,7 @@ def register_extra_structs(self, struct_list: List[NNCFExtraConfigStruct]):
7878
def get_extra_struct(self, struct_cls: Type[NNCFExtraConfigStruct]) -> NNCFExtraConfigStruct:
7979
return self.__nncf_extra_structs[struct_cls.get_id()]
8080

81-
def has_extra_struct(self, struct_cls: Type[NNCFExtraConfigStruct]) -> NNCFExtraConfigStruct:
81+
def has_extra_struct(self, struct_cls: Type[NNCFExtraConfigStruct]) -> bool:
8282
return struct_cls.get_id() in self.__nncf_extra_structs
8383

8484
def get_all_extra_structs(self) -> List[NNCFExtraConfigStruct]:
@@ -108,7 +108,7 @@ def get_redefinable_global_param_value_for_algo(self, param_name: str, algo_name
108108
return param
109109

110110
@staticmethod
111-
def schema() -> Dict:
111+
def schema() -> Dict[str, Any]:
112112
"""
113113
Returns the JSONSchema against which the input data formats (.json or Python dict) are validated.
114114
"""
@@ -124,15 +124,15 @@ def _is_path_to_algorithm_name(path_parts: List[str]) -> bool:
124124
)
125125

126126
@staticmethod
127-
def validate(loaded_json):
127+
def validate(loaded_json: Dict[str, Any]) -> None:
128128
try:
129129
jsonschema.validate(loaded_json, NNCFConfig.schema())
130130
except jsonschema.ValidationError as e:
131131
nncf_logger.error("Invalid NNCF config supplied!")
132132
absolute_path_parts = [str(x) for x in e.absolute_path]
133133
if not NNCFConfig._is_path_to_algorithm_name(absolute_path_parts):
134134
e.message += f"\nRefer to the NNCF config schema documentation at {SCHEMA_VISUALIZATION_URL}"
135-
e.schema = "*schema too long for stdout display*"
135+
e.schema = "*schema too long for stdout display*" # type: ignore[assignment]
136136
raise e
137137

138138
# Need to make the error more algo-specific in case the config was so bad that no

nncf/config/extractors.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def extract_algorithm_names(config: NNCFConfig) -> List[str]:
3131
return retval
3232

3333

34-
def extract_algo_specific_config(config: NNCFConfig, algo_name_to_match: str) -> Dict:
34+
def extract_algo_specific_config(config: NNCFConfig, algo_name_to_match: str) -> Dict[str, Any]:
3535
"""
3636
Extracts a .json sub-dictionary for a given compression algorithm from the
3737
common NNCFConfig.
@@ -77,7 +77,7 @@ def extract_algo_specific_config(config: NNCFConfig, algo_name_to_match: str) ->
7777
return next(iter(matches))
7878

7979

80-
def extract_range_init_params(config: NNCFConfig, algorithm_name: str = "quantization") -> Optional[Dict[str, object]]:
80+
def extract_range_init_params(config: NNCFConfig, algorithm_name: str = "quantization") -> Optional[Dict[str, Any]]:
8181
"""
8282
Extracts parameters of the quantization range initialization algorithm from the
8383
compression algorithm NNCFconfig.
@@ -90,7 +90,6 @@ def extract_range_init_params(config: NNCFConfig, algorithm_name: str = "quantiz
9090
algo_config = extract_algo_specific_config(config, algorithm_name)
9191
init_range_config_dict_or_list = algo_config.get("initializer", {}).get("range", {})
9292

93-
range_init_args = None
9493
try:
9594
range_init_args = config.get_extra_struct(QuantizationRangeInitArgs)
9695
except KeyError:
@@ -120,7 +119,7 @@ def extract_range_init_params(config: NNCFConfig, algorithm_name: str = "quantiz
120119

121120
if max_num_init_samples == 0:
122121
return None
123-
if range_init_args is None:
122+
if not isinstance(range_init_args, QuantizationRangeInitArgs):
124123
raise ValueError(
125124
"Should run range initialization as specified via config,"
126125
"but the initializing data loader is not provided as an extra struct. "
@@ -161,7 +160,7 @@ class BNAdaptDataLoaderNotFoundError(RuntimeError):
161160
pass
162161

163162

164-
def get_bn_adapt_algo_kwargs(nncf_config: NNCFConfig, params: Dict[str, Any]) -> Dict[str, Any]:
163+
def get_bn_adapt_algo_kwargs(nncf_config: NNCFConfig, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
165164
num_bn_adaptation_samples = params.get("num_bn_adaptation_samples", NUM_BN_ADAPTATION_SAMPLES)
166165

167166
if num_bn_adaptation_samples == 0:
@@ -175,6 +174,11 @@ def get_bn_adapt_algo_kwargs(nncf_config: NNCFConfig, params: Dict[str, Any]) ->
175174
"because the data loader is not provided as an extra struct. Refer to the "
176175
"`NNCFConfig.register_extra_structs` method and the `BNAdaptationInitArgs` class."
177176
) from None
177+
178+
if not isinstance(args, BNAdaptationInitArgs):
179+
raise BNAdaptDataLoaderNotFoundError(
180+
"The extra struct for batch-norm adaptation must be an instance of the BNAdaptationInitArgs class."
181+
)
178182
params = {
179183
"num_bn_adaptation_samples": num_bn_adaptation_samples,
180184
"data_loader": args.data_loader,
@@ -183,7 +187,7 @@ def get_bn_adapt_algo_kwargs(nncf_config: NNCFConfig, params: Dict[str, Any]) ->
183187
return params
184188

185189

186-
def extract_accuracy_aware_training_params(config: NNCFConfig) -> Dict[str, object]:
190+
def extract_accuracy_aware_training_params(config: NNCFConfig) -> Dict[str, Any]:
187191
"""
188192
Extracts accuracy aware training parameters from NNCFConfig.
189193
@@ -196,7 +200,7 @@ class NNCFAlgorithmNames:
196200
FILTER_PRUNING = "filter_pruning"
197201
SPARSITY = ["rb_sparsity", "magnitude_sparsity", "const_sparsity"]
198202

199-
def validate_accuracy_aware_schema(config: NNCFConfig, params: Dict[str, object]):
203+
def validate_accuracy_aware_schema(config: NNCFConfig, params: Dict[str, Any]) -> None:
200204
from nncf.common.accuracy_aware_training import AccuracyAwareTrainingMode
201205

202206
if params["mode"] == AccuracyAwareTrainingMode.EARLY_EXIT:

nncf/config/schema.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# limitations under the License.
1111

1212
import logging
13-
from typing import Dict
13+
from typing import Any, Dict
1414

1515
import jsonschema
1616

@@ -148,7 +148,9 @@
148148
}
149149

150150

151-
def validate_single_compression_algo_schema(single_compression_algo_dict: Dict, ref_vs_algo_schema: Dict):
151+
def validate_single_compression_algo_schema(
152+
single_compression_algo_dict: Dict[str, Any], ref_vs_algo_schema: Dict[str, Any]
153+
) -> None:
152154
"""single_compression_algo_dict must conform to BASIC_COMPRESSION_ALGO_SCHEMA (and possibly has other
153155
algo-specific properties"""
154156
algo_name = single_compression_algo_dict["algorithm"]
@@ -172,7 +174,7 @@ def validate_single_compression_algo_schema(single_compression_algo_dict: Dict,
172174
raise e
173175

174176

175-
def validate_accuracy_aware_training_schema(single_compression_algo_dict: Dict):
177+
def validate_accuracy_aware_training_schema(single_compression_algo_dict: Dict[str, Any]) -> None:
176178
"""
177179
Checks accuracy_aware_training section.
178180
"""

nncf/config/schemata/algo/quantization.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@
160160
PER_LAYER_RANGE_INIT_CONFIG_PROPERTIES = {
161161
"type": "object",
162162
"properties": {
163-
**BASIC_RANGE_INIT_CONFIG_PROPERTIES["properties"],
163+
**BASIC_RANGE_INIT_CONFIG_PROPERTIES["properties"], # type: ignore[dict-item]
164164
**SCOPING_PROPERTIES,
165165
"target_quantizer_group": with_attributes(
166166
STRING,
@@ -338,7 +338,7 @@
338338
"parameters have a better chance to get fine-tuned to values that result in good accuracy.",
339339
"properties": {
340340
"batchnorm_adaptation": BATCHNORM_ADAPTATION_SCHEMA,
341-
**RANGE_INIT_CONFIG_PROPERTIES["initializer"]["properties"],
341+
**RANGE_INIT_CONFIG_PROPERTIES["initializer"]["properties"], # type: ignore[dict-item]
342342
"precision": PRECISION_INITIALIZER_SCHEMA,
343343
},
344344
"additionalProperties": False,

nncf/config/schemata/basic.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
11-
from typing import Dict
11+
from typing import Any, Dict, List, Optional
1212

1313
NUMBER = {"type": "number"}
1414
STRING = {"type": "string"}
@@ -17,22 +17,22 @@
1717
ARRAY_OF_STRINGS = {"type": "array", "items": STRING}
1818

1919

20-
def annotated_enum(names_vs_description: Dict[str, str]) -> Dict:
20+
def annotated_enum(names_vs_description: Dict[str, str]) -> Dict[str, List[Dict[str, str]]]:
2121
retval_list = []
2222
for name, descr in names_vs_description.items():
2323
retval_list.append({"const": name, "title": name, "description": descr})
2424
return {"oneOf": retval_list}
2525

2626

27-
def make_string_or_array_of_strings_schema(addtl_dict_entries: Dict = None) -> Dict:
27+
def make_string_or_array_of_strings_schema(addtl_dict_entries: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
2828
if addtl_dict_entries is None:
2929
addtl_dict_entries = {}
3030
retval = {"type": ["array", "string"], "items": {"type": "string"}}
3131
retval.update(addtl_dict_entries)
3232
return retval
3333

3434

35-
def make_object_or_array_of_objects_schema(single_object_schema: Dict = None) -> Dict:
35+
def make_object_or_array_of_objects_schema(single_object_schema: Dict[str, Any]) -> Dict[str, Any]:
3636
retval = {
3737
"oneOf": [
3838
{
@@ -45,6 +45,6 @@ def make_object_or_array_of_objects_schema(single_object_schema: Dict = None) ->
4545
return retval
4646

4747

48-
def with_attributes(schema: Dict, **kwargs) -> Dict:
48+
def with_attributes(schema: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
4949
retval = {**schema, **kwargs}
5050
return retval

nncf/config/schemata/experimental_schema.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
# Experimental Quantization
3636
########################################################################################################################
3737
EXPERIMENTAL_QUANTIZATION_SCHEMA = copy.deepcopy(QUANTIZATION_SCHEMA)
38-
EXPERIMENTAL_QUANTIZATION_SCHEMA["properties"]["algorithm"]["const"] = EXPERIMENTAL_QUANTIZATION_ALGO_NAME_IN_CONFIG
38+
EXPERIMENTAL_QUANTIZATION_SCHEMA["properties"]["algorithm"]["const"] = EXPERIMENTAL_QUANTIZATION_ALGO_NAME_IN_CONFIG # type: ignore[index]
3939

4040
########################################################################################################################
4141
# BootstrapNAS

nncf/config/structures.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"""
1212
Structures for passing live Python objects into NNCF algorithms.
1313
"""
14-
from typing import Callable, Optional
14+
from typing import Any, Callable, Optional
1515

1616
from nncf.common.initialization.dataloader import NNCFDataLoader
1717
from nncf.common.utils.api_marker import api
@@ -47,7 +47,7 @@ def data_loader(self) -> NNCFDataLoader:
4747
return self._data_loader
4848

4949
@property
50-
def device(self) -> str:
50+
def device(self) -> Optional[str]:
5151
return self._device
5252

5353
@classmethod
@@ -74,7 +74,7 @@ def data_loader(self) -> NNCFDataLoader:
7474
return self._data_loader
7575

7676
@property
77-
def device(self) -> str:
77+
def device(self) -> Optional[str]:
7878
return self._device
7979

8080
@classmethod
@@ -91,7 +91,7 @@ class ModelEvaluationArgs(NNCFExtraConfigStruct):
9191
the evaluation split of the dataset corresponding to the model.
9292
"""
9393

94-
def __init__(self, eval_fn: Callable):
94+
def __init__(self, eval_fn: Callable[..., Any]):
9595
self.eval_fn = eval_fn
9696

9797
@classmethod

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ strict = true
9393
# https://github.com/hauntsaninja/no_implicit_optional
9494
implicit_optional = true
9595
files = [
96+
"nncf/config",
9697
"nncf/common/sparsity",
9798
"nncf/common/graph",
9899
"nncf/common/accuracy_aware_training/",

0 commit comments

Comments
 (0)