Skip to content

Commit 6031ccc

Browse files
authored
[ONNX] Fix weight quantization for GroupConvolution (#3126)
### Reason for changes
Weights were not quantized for GroupConv.

### Related tickets
158085

### Tests
ptq perf run 84
1 parent 5189aab commit 6031ccc

File tree

8 files changed

+40
-73
lines changed

8 files changed

+40
-73
lines changed

nncf/onnx/graph/metatypes/groups.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
# limitations under the License.
1111

1212
from nncf.onnx.graph.metatypes import onnx_metatypes
13+
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpWithWeightsMetatype
14+
from nncf.onnx.graph.metatypes.onnx_metatypes import get_operator_metatypes
1315

1416
QUANTIZE_AGNOSTIC_OPERATIONS = [
1517
onnx_metatypes.ONNXGlobalMaxPoolMetatype,
@@ -67,14 +69,19 @@
6769
onnx_metatypes.ONNXMinimumMetatype,
6870
]
6971

70-
7172
CONSTANT_WEIGHT_LAYER_METATYPES = [
72-
onnx_metatypes.ONNXConvolutionMetatype,
73-
onnx_metatypes.ONNXDepthwiseConvolutionMetatype,
74-
onnx_metatypes.ONNXConvolutionTransposeMetatype,
75-
onnx_metatypes.ONNXEmbeddingMetatype,
73+
metatype
74+
for metatype in get_operator_metatypes()
75+
if issubclass(metatype, ONNXOpWithWeightsMetatype) and metatype.weight_port_ids
7676
]
7777

78+
POSSIBLE_WEIGHT_LAYER_METATYPES = [
79+
metatype
80+
for metatype in get_operator_metatypes()
81+
if issubclass(metatype, ONNXOpWithWeightsMetatype) and metatype.possible_weight_ports
82+
]
83+
84+
OPERATIONS_WITH_WEIGHTS = list(set().union(CONSTANT_WEIGHT_LAYER_METATYPES, POSSIBLE_WEIGHT_LAYER_METATYPES))
7885

7986
LINEAR_OPERATIONS = [
8087
onnx_metatypes.ONNXConvolutionMetatype,
@@ -124,11 +131,6 @@
124131
onnx_metatypes.ONNXMeanMetatype,
125132
]
126133

127-
OPERATIONS_WITH_WEIGHTS = [
128-
*CONSTANT_WEIGHT_LAYER_METATYPES,
129-
*MATMUL_METATYPES,
130-
]
131-
132134

133135
BATCH_NORMALIZATION_OPERATIONS = [
134136
onnx_metatypes.ONNXBatchNormMetatype,

nncf/onnx/graph/metatypes/onnx_metatypes.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,17 @@ def determine_subtype(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> Opti
5858
class ONNXOpWithWeightsMetatype(ONNXOpMetatype):
5959
"""
6060
Metatype which could have weights.
61-
62-
:param weight_channel_axis: Axis for weight per-channel quantization, meaning the number of output filters.
63-
:param weight_port_ids: Input ports of the node's weight.
64-
If the value is None the weight_port_id should be determined dynamically.
65-
:param bias_port_id: Input port of the node's bias.
66-
If the value is None it means that the Metatype does not have bias.
61+
:param weight_channel_axis: Axis for weight per-channel quantization.
62+
:param weight_port_ids: Constant input ports of the node's weight. Defaults to an empty list.
63+
:param bias_port_id: Input port of the node's bias. If the value is None,
64+
it means that the Metatype does not have bias. Defaults to None.
65+
:param possible_weight_ports: Input ports on which weight could be laid. Defaults to an empty list.
6766
"""
6867

6968
weight_channel_axis: int
70-
weight_port_ids: Optional[List[int]] = None
69+
weight_port_ids: List[int] = []
7170
bias_port_id: Optional[int] = None
71+
possible_weight_ports: List[int] = []
7272

7373

7474
@ONNX_OPERATION_METATYPES.register(is_subtype=True)
@@ -131,19 +131,17 @@ class ONNXGemmMetatype(ONNXOpWithWeightsMetatype):
131131
op_names = ["Gemm"]
132132
hw_config_names = [HWConfigOpName.MATMUL]
133133
weight_channel_axis = -1 # For port_id=1
134-
weight_port_ids = None
135134
bias_port_id = 2
136135
possible_weight_ports = [0, 1]
137136
output_channel_axis = -1
138137

139138

140139
@ONNX_OPERATION_METATYPES.register()
141-
class ONNXMatMulMetatype(ONNXOpMetatype):
140+
class ONNXMatMulMetatype(ONNXOpWithWeightsMetatype):
142141
name = "MatMulOp"
143142
op_names = ["MatMul"]
144143
hw_config_names = [HWConfigOpName.MATMUL]
145144
weight_channel_axis = -1 # For port_id=1
146-
weight_port_ids = None
147145
bias_port_id = 2
148146
possible_weight_ports = [0, 1]
149147
output_channel_axis = -1
@@ -454,7 +452,7 @@ class ONNXReciprocalMetatype(ONNXOpMetatype):
454452

455453

456454
@ONNX_OPERATION_METATYPES.register(is_subtype=True)
457-
class ONNXEmbeddingMetatype(ONNXOpMetatype):
455+
class ONNXEmbeddingMetatype(ONNXOpWithWeightsMetatype):
458456
name = "EmbeddingOp"
459457
hw_config_names = [HWConfigOpName.EMBEDDING]
460458
weight_port_ids = [0]

nncf/onnx/graph/nncf_graph_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from nncf.common.graph.operator_metatypes import InputNoopMetatype
2424
from nncf.common.graph.operator_metatypes import OutputNoopMetatype
2525
from nncf.onnx.graph.metatypes.groups import CONSTANT_WEIGHT_LAYER_METATYPES
26-
from nncf.onnx.graph.metatypes.groups import MATMUL_METATYPES
2726
from nncf.onnx.graph.metatypes.groups import OPERATIONS_WITH_BIAS
27+
from nncf.onnx.graph.metatypes.groups import POSSIBLE_WEIGHT_LAYER_METATYPES
2828
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXGemmMetatype
2929
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpMetatype
3030
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpWithWeightsMetatype
@@ -95,7 +95,7 @@ def get_possible_weight_port_ids(metatype: ONNXOpMetatype) -> List[int]:
9595
:param metatype: Metatype.
9696
:return: Port ids.
9797
"""
98-
if metatype in MATMUL_METATYPES:
98+
if metatype in POSSIBLE_WEIGHT_LAYER_METATYPES:
9999
return metatype.possible_weight_ports
100100
return []
101101

nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: onnx.ModelProto
8181
@staticmethod
8282
def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]:
8383
activation_port = 0
84-
85-
if hasattr(node.metatype, "possible_weight_ports"):
84+
if node.metatype.possible_weight_ports:
8685
activation_ports = deepcopy(node.metatype.possible_weight_ports)
8786
for weight_port in node.layer_attributes.weight_attrs:
8887
activation_ports.remove(weight_port)

tests/post_training/data/ptq_reference_data.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ torchvision/resnet18_backend_FX_TORCH:
4141
torchvision/mobilenet_v3_small_BC_backend_FP32:
4242
metric_value: 0.6766
4343
torchvision/mobilenet_v3_small_BC_backend_OV:
44-
metric_value: 0.6669
44+
metric_value: 0.6681
4545
torchvision/mobilenet_v3_small_BC_backend_ONNX:
4646
metric_value: 0.6679
4747
torchvision/mobilenet_v3_small_BC_backend_FX_TORCH:
@@ -103,7 +103,7 @@ timm/dpn68_backend_CUDA_TORCH:
103103
timm/dpn68_backend_FP32:
104104
metric_value: 0.76342
105105
timm/dpn68_backend_ONNX:
106-
metric_value: 0.75906
106+
metric_value: 0.7592
107107
timm/dpn68_backend_OV:
108108
metric_value: 0.75972
109109
timm/dpn68_backend_TORCH:
@@ -201,7 +201,7 @@ timm/regnetx_002_backend_CUDA_TORCH:
201201
timm/regnetx_002_backend_FP32:
202202
metric_value: 0.68756
203203
timm/regnetx_002_backend_ONNX:
204-
metric_value: 0.6848
204+
metric_value: 0.6854
205205
timm/regnetx_002_backend_OV:
206206
metric_value: 0.6852
207207
timm/regnetx_002_backend_TORCH:
@@ -211,7 +211,7 @@ timm/resnest14d_backend_CUDA_TORCH:
211211
timm/resnest14d_backend_FP32:
212212
metric_value: 0.75516
213213
timm/resnest14d_backend_ONNX:
214-
metric_value: 0.75428
214+
metric_value: 0.7538
215215
timm/resnest14d_backend_OV:
216216
metric_value: 0.75
217217
timm/resnest14d_backend_TORCH:
@@ -253,7 +253,7 @@ timm/visformer_small_backend_CUDA_TORCH:
253253
timm/visformer_small_backend_FP32:
254254
metric_value: 0.82098
255255
timm/visformer_small_backend_ONNX:
256-
metric_value: 0.81562
256+
metric_value: 0.8160
257257
timm/visformer_small_backend_OV:
258258
metric_value: 0.81674
259259
timm/visformer_small_backend_TORCH:

tests/post_training/data/ptq_reference_data_2024.5.yaml

Lines changed: 0 additions & 2 deletions
This file was deleted.

tests/post_training/data/wc_reference_data.yaml

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,40 +7,44 @@ tinyllama_data_aware_backend_OV:
77
num_int4: 94
88
num_int8: 124
99
tinyllama_data_aware_awq_stateful_backend_OV:
10-
metric_value: 0.85571
10+
metric_value: 0.85616
1111
num_int4: 94
1212
num_int8: 124
1313
tinyllama_data_aware_awq_scale_estimation_backend_OV:
14-
metric_value: 0.86355
14+
metric_value: 0.85502
1515
num_int4: 94
1616
num_int8: 124
1717
tinyllama_data_aware_awq_scale_estimation_stateful_backend_OV:
18-
metric_value: 0.86355
18+
metric_value: 0.85502
1919
num_int4: 94
2020
num_int8: 124
2121
tinyllama_int8_data_free_backend_TORCH:
2222
metric_value: 0.95624
2323
num_int4: 0
2424
num_int8: 312
2525
tinyllama_data_aware_gptq_scale_estimation_stateful_backend_OV:
26-
metric_value: 0.86697
26+
metric_value: 0.86503
2727
num_int4: 94
2828
num_int8: 124
2929
metrics_xfail_reason: "Issue-148819"
3030
tinyllama_scale_estimation_per_channel_backend_OV:
31-
metric_value: 0.80798
31+
metric_value: 0.81389
3232
num_int4: 188
3333
num_int8: 124
3434
tinyllama_data_aware_lora_stateful_backend_OV:
3535
metric_value: 0.83446
3636
num_int4: 94
3737
num_int8: 500
3838
tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV:
39-
metric_value: 0.87132
39+
metric_value: 0.88663
4040
num_int4: 11
4141
num_int8: 290
4242
metrics_xfail_reason: "Issue-148819"
4343
tinyllama_awq_backup_mode_none_backend_OV:
44-
metric_value: 0.85679
44+
metric_value: 0.84783
4545
num_int4: 208
4646
num_int8: 0
47+
tinyllama_int4_data_free_backend_TORCH:
48+
metric_value: 0.73873
49+
num_int4: 114
50+
num_int8: 84

tests/post_training/data/wc_reference_data_2024.5.yaml

Lines changed: 0 additions & 34 deletions
This file was deleted.

0 commit comments

Comments
 (0)