Skip to content

Commit d53b445

Browse files
committed
fix aten op error
1 parent 22c5e51 commit d53b445

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

nncf/experimental/torch/fx/nncf_graph_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
2525
from nncf.torch.dynamic_graph.layer_attributes_handlers import apply_args_defaults
2626
from nncf.torch.graph.graph import PTNNCFGraph
27-
from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES
27+
from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES, FX_OPERATOR_METATYPES
2828

2929

3030
class GraphConverter:
@@ -95,6 +95,7 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
9595
# TODO(dlyakhov): get correct nodes types from this nodes as well
9696
node_type = str(node.target)
9797
node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_func(node_type)
98+
node_metatype = FX_OPERATOR_METATYPES.get_operator_metatype_by_func(node_type) if node_metatype == UnknownMetatype else node_metatype
9899
else:
99100
node_type = node.op
100101
node_metatype = UnknownMetatype

nncf/torch/graph/operator_metatypes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ModuleAttributes = TypeVar("ModuleAttributes", bound=BaseLayerAttributes)
2929

3030
PT_OPERATOR_METATYPES = OperatorMetatypeRegistry("operator_metatypes")
31+
FX_OPERATOR_METATYPES = OperatorMetatypeRegistry("operator_metatypes")
3132

3233

3334
class PTOperatorMetatype(OperatorMetatype):
@@ -967,7 +968,7 @@ class PTEmbeddingMetatype(PTOperatorMetatype):
967968
weight_port_ids = [1]
968969

969970

970-
@PT_OPERATOR_METATYPES.register()
971+
@FX_OPERATOR_METATYPES.register()
971972
class PTAtenEmbeddingMetatype(PTOperatorMetatype):
972973
name = "AtenEmbeddingOp"
973974
module_to_function_names = {NamespaceTarget.ATEN: ["embedding"]}

0 commit comments

Comments
 (0)