Skip to content

Commit

Permalink
fix aten op error
Browse files Browse the repository at this point in the history
  • Loading branch information
anzr299 committed Feb 6, 2025
1 parent 22c5e51 commit d53b445
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion nncf/experimental/torch/fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
from nncf.torch.dynamic_graph.layer_attributes_handlers import apply_args_defaults
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES
from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES, FX_OPERATOR_METATYPES


class GraphConverter:
Expand Down Expand Up @@ -95,6 +95,7 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
# TODO(dlyakhov): get correct nodes types from this nodes as well
node_type = str(node.target)
node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_func(node_type)
node_metatype = FX_OPERATOR_METATYPES.get_operator_metatype_by_func(node_type) if node_metatype == UnknownMetatype else node_metatype
else:
node_type = node.op
node_metatype = UnknownMetatype
Expand Down
3 changes: 2 additions & 1 deletion nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ModuleAttributes = TypeVar("ModuleAttributes", bound=BaseLayerAttributes)

PT_OPERATOR_METATYPES = OperatorMetatypeRegistry("operator_metatypes")
FX_OPERATOR_METATYPES = OperatorMetatypeRegistry("operator_metatypes")


class PTOperatorMetatype(OperatorMetatype):
Expand Down Expand Up @@ -967,7 +968,7 @@ class PTEmbeddingMetatype(PTOperatorMetatype):
weight_port_ids = [1]


@PT_OPERATOR_METATYPES.register()
@FX_OPERATOR_METATYPES.register()
class PTAtenEmbeddingMetatype(PTOperatorMetatype):
name = "AtenEmbeddingOp"
module_to_function_names = {NamespaceTarget.ATEN: ["embedding"]}
Expand Down

0 comments on commit d53b445

Please sign in to comment.