Skip to content

Commit 44103b4

Browse files
committed
revert breaking changes
1 parent 5991aef commit 44103b4

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

nncf/common/graph/operator_metatypes.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(self, name: str):
7878
"""
7979
super().__init__(name)
8080
self._op_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {}
81+
self._func_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {}
8182

8283
def register(self, name: Optional[str] = None, is_subtype: bool = False) -> Callable[..., Type[OperatorMetatype]]:
8384
"""
@@ -137,6 +138,18 @@ def get_operator_metatype_by_op_name(self, op_name: str) -> Type[OperatorMetatyp
137138
return UnknownMetatype
138139
return self._op_name_to_op_meta_dict[op_name]
139140

141+
def get_operator_metatype_by_func(self, func_name: str) -> Type[OperatorMetatype]:
142+
"""
143+
Returns the operator metatype by function name.
144+
145+
:param func_name: The function name.
146+
:return: The operator metatype.
147+
"""
148+
if func_name not in self._func_name_to_op_meta_dict:
149+
return UnknownMetatype
150+
obj = self._func_name_to_op_meta_dict[func_name]
151+
return obj
152+
140153

141154
NOOP_METATYPES = Registry("noop_metatypes")
142155
INPUT_NOOP_METATYPES = Registry("input_noop_metatypes")

nncf/experimental/torch/fx/nncf_graph_builder.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
7474
:param model: Given GraphModule.
7575
:return: Node's type and metatype.
7676
"""
77+
node_type_name = None
7778
if node.op == "placeholder":
7879
node_type = "input"
7980
node_metatype = om.PTInputNoopMetatype
@@ -85,13 +86,15 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
8586
node_metatype = om.PTConstNoopMetatype
8687
elif node.op in ("call_function",):
8788
if hasattr(node.target, "overloadpacket"):
88-
node_type = str(node.target.overloadpacket).split(".")[1]
89+
node_type = str(node.target.overloadpacket)
90+
node_type_name = node_type.split(".")[1]
8991
elif node.target.__name__ == "getitem":
90-
node_type = "__getitem__"
92+
node_type = "aten.__getitem__"
93+
node_type_name = "__getitem__"
9194
else:
9295
# TODO(dlyakhov): get correct nodes types from this nodes as well
9396
node_type = str(node.target)
94-
node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
97+
node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_func(node_type)
9598
else:
9699
node_type = node.op
97100
node_metatype = UnknownMetatype

0 commit comments

Comments
 (0)