Skip to content

Commit b83c10b

Browse files
comments
1 parent 39971e0 commit b83c10b

File tree

4 files changed

+4
-34
lines changed

4 files changed

+4
-34
lines changed

nncf/experimental/torch2/function_hook/extractor.py

-2
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def apply_args_to_kwargs(
6161
:param args: The positional arguments.
6262
:param kwargs: The keyword arguments.
6363
:param indexed_args: The list of pairs of indexes and names.
64-
6564
:return: A dictionary of keyword arguments with the applied arguments and keyword arguments.
6665
"""
6766
args_dict: Dict[str, Any] = dict()
@@ -77,7 +76,6 @@ def apply_args_to_kwargs(
7776
def extract_bn(model: nn.Module, graph: PTNNCFGraph, node: NNCFNode) -> ExtractedFunc:
7877
"""
7978
Extract batch_norm operation.
80-
If source modules inhered from nn.BatchNorm1d, nn.BatchNorm2d, or nn.BatchNorm3d return torch BatchNorm module.
8179
8280
:param model: Source model.
8381
:param graph: Graph of source model.

nncf/experimental/torch2/model_transformer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ class PT2ModelTransformer(ModelTransformer[GraphModelWrapper]):
3838
def __init__(self, model: GraphModelWrapper):
3939
super().__init__(model)
4040

41-
self._command_transformation_ordered_pairs: TRANSFORMATION_PAIRS = [
41+
self._command_transformation_ordered_pairs: TRANSFORMATION_PAIRS = (
4242
(PT2InsertionCommand, self._apply_insertion_transformations),
4343
(PTBiasCorrectionCommand, self._apply_bias_correction_transformations),
44-
]
44+
)
4545

4646
def transform(self, transformation_layout: TransformationLayout) -> GraphModelWrapper:
4747
"""

tests/torch/ptq/test_fast_bias_correction.py

+1-15
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def check_bias(model: NNCFNetwork, ref_bias: list):
5555
for node in nncf_graph.get_all_nodes():
5656
if not is_node_with_fused_bias(node, nncf_graph):
5757
continue
58-
bias_value = get_fused_bias_value(node, nncf_graph, model)
58+
bias_value = get_fused_bias_value(node, nncf_graph, model).cpu()
5959
# TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189
6060
assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}"
6161
return
@@ -77,17 +77,3 @@ def backend_specific_model(model: bool, tmp_dir: str):
7777
@staticmethod
7878
def fn_to_type(tensor):
7979
return torch.Tensor(tensor).cuda()
80-
81-
@staticmethod
82-
def check_bias(model: NNCFNetwork, ref_bias: list):
83-
ref_bias = torch.Tensor(ref_bias)
84-
nncf_graph = NNCFGraphFactory.create(model)
85-
for node in nncf_graph.get_all_nodes():
86-
if not is_node_with_fused_bias(node, nncf_graph):
87-
continue
88-
bias_value = get_fused_bias_value(node, nncf_graph, model).cpu()
89-
# TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189
90-
assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}"
91-
return
92-
msg = "Not found node with bias"
93-
raise ValueError(msg)

tests/torch2/function_hook/quantization/test_fast_bias_correction.py

+1-15
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def check_bias(model: GraphModelWrapper, ref_bias: list):
5454
for node in nncf_graph.get_all_nodes():
5555
if not is_node_with_fused_bias(node, nncf_graph):
5656
continue
57-
bias_value = get_fused_bias_value(node, nncf_graph, model.model)
57+
bias_value = get_fused_bias_value(node, nncf_graph, model.model).cpu()
5858
# TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189
5959
assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}"
6060
return
@@ -76,17 +76,3 @@ def backend_specific_model(model: bool, tmp_dir: str):
7676
@staticmethod
7777
def fn_to_type(tensor):
7878
return torch.Tensor(tensor).cuda()
79-
80-
@staticmethod
81-
def check_bias(model: GraphModelWrapper, ref_bias: list):
82-
ref_bias = torch.Tensor(ref_bias)
83-
nncf_graph = model.get_graph()
84-
for node in nncf_graph.get_all_nodes():
85-
if not is_node_with_fused_bias(node, nncf_graph):
86-
continue
87-
bias_value = get_fused_bias_value(node, nncf_graph, model.model).cpu()
88-
# TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189
89-
assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}"
90-
return
91-
msg = "Not found node with bias"
92-
raise ValueError(msg)

0 commit comments

Comments
 (0)