Skip to content

Commit 5991aef

Browse files
committed
update operator metatypes
1 parent 98cb4dc commit 5991aef

File tree

1 file changed

+69
-26
lines changed

1 file changed

+69
-26
lines changed

nncf/torch/graph/operator_metatypes.py

Lines changed: 69 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ class PTNoopMetatype(PTOperatorMetatype):
193193
NamespaceTarget.TORCH_NN_FUNCTIONAL: [],
194194
NamespaceTarget.TORCH_TENSOR: ["contiguous", "clone", "detach", "detach_", "to"],
195195
NamespaceTarget.TORCH: ["clone", "detach", "detach_"],
196+
NamespaceTarget.ATEN: ["contiguous", "clone", "to"],
196197
}
197198

198199

@@ -246,7 +247,7 @@ class PTConv1dMetatype(PTOperatorMetatype):
246247
class PTModuleDepthwiseConv2dSubtype(PTModuleDepthwiseConvOperatorSubtype):
247248
name = "Conv2DOp"
248249
hw_config_names = [HWConfigOpName.DEPTHWISECONVOLUTION]
249-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv2d"]}
250+
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv2d"], NamespaceTarget.ATEN: ["conv2d"]}
250251
output_channel_axis = 1
251252
num_expected_input_edges = 2
252253
weight_port_ids = [1]
@@ -257,7 +258,7 @@ class PTModuleDepthwiseConv2dSubtype(PTModuleDepthwiseConvOperatorSubtype):
257258
class PTModuleConv2dMetatype(PTModuleOperatorSubtype):
258259
name = "Conv2DOp"
259260
hw_config_names = [HWConfigOpName.CONVOLUTION]
260-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv2d"]}
261+
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv2d"], NamespaceTarget.ATEN: ["conv2d"]}
261262
subtypes = [PTModuleDepthwiseConv2dSubtype]
262263
output_channel_axis = 1
263264
num_expected_input_edges = 2
@@ -280,7 +281,7 @@ class PTDepthwiseConv2dSubtype(PTDepthwiseConvOperatorSubtype):
280281
class PTConv2dMetatype(PTOperatorMetatype):
281282
name = "Conv2DOp"
282283
hw_config_names = [HWConfigOpName.CONVOLUTION]
283-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv2d"]}
284+
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv2d"], NamespaceTarget.ATEN: ["conv2d"]}
284285
subtypes = [PTModuleConv2dMetatype, PTDepthwiseConv2dSubtype]
285286
output_channel_axis = 1
286287
num_expected_input_edges = 2
@@ -372,7 +373,10 @@ class PTModuleConvTranspose2dMetatype(PTModuleOperatorSubtype):
372373
class PTConvTranspose2dMetatype(PTOperatorMetatype):
373374
name = "ConvTranspose2DOp"
374375
hw_config_names = [HWConfigOpName.CONVOLUTION]
375-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv_transpose2d"]}
376+
module_to_function_names = {
377+
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv_transpose2d"],
378+
NamespaceTarget.ATEN: ["conv_transpose2d"],
379+
}
376380
subtypes = [PTModuleConvTranspose2dMetatype]
377381
output_channel_axis = 1
378382
num_expected_input_edges = 2
@@ -433,7 +437,7 @@ class PTModuleLinearMetatype(PTModuleOperatorSubtype):
433437
@PT_OPERATOR_METATYPES.register()
434438
class PTLinearMetatype(PTOperatorMetatype):
435439
name = "LinearOp"
436-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["linear"]}
440+
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["linear"], NamespaceTarget.ATEN: ["linear"]}
437441
hw_config_names = [HWConfigOpName.MATMUL]
438442
subtypes = [PTModuleLinearMetatype]
439443
output_channel_axis = -1
@@ -451,14 +455,20 @@ class PTHardTanhMetatype(PTOperatorMetatype):
451455
@PT_OPERATOR_METATYPES.register()
452456
class PTHardSwishMetatype(PTOperatorMetatype):
453457
name = "HardSwishOp"
454-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["hardswish", "hardswish_"]}
458+
module_to_function_names = {
459+
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["hardswish", "hardswish_"],
460+
NamespaceTarget.ATEN: ["hardswish", "hardswish_"],
461+
}
455462
num_expected_input_edges = 1
456463

457464

458465
@PT_OPERATOR_METATYPES.register()
459466
class PTHardSigmoidMetatype(PTOperatorMetatype):
460467
name = "HardSigmoidOp"
461-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["hardsigmoid"]}
468+
module_to_function_names = {
469+
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["hardsigmoid"],
470+
NamespaceTarget.ATEN: ["hardsigmoid"],
471+
}
462472
num_expected_input_edges = 1
463473

464474

@@ -502,7 +512,10 @@ class PTModuleLayerNormMetatype(PTModuleOperatorSubtype):
502512
@PT_OPERATOR_METATYPES.register()
503513
class PTLayerNormMetatype(PTOperatorMetatype):
504514
name = "LayerNormOp"
505-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["layer_norm"]}
515+
module_to_function_names = {
516+
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["layer_norm"],
517+
NamespaceTarget.ATEN: ["layer_norm"],
518+
}
506519
hw_config_names = [HWConfigOpName.MVN]
507520
subtypes = [PTModuleLayerNormMetatype]
508521
num_expected_input_edges = 1
@@ -530,7 +543,7 @@ class PTGroupNormMetatype(PTOperatorMetatype):
530543
class PTGELUMetatype(PTOperatorMetatype):
531544
name = "GeluOp"
532545
hw_config_names = [HWConfigOpName.GELU]
533-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["gelu"]}
546+
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["gelu"], NamespaceTarget.ATEN: ["gelu"]}
534547

535548

536549
@PT_OPERATOR_METATYPES.register()
@@ -546,6 +559,7 @@ class PTSigmoidMetatype(PTOperatorMetatype):
546559
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["sigmoid"],
547560
NamespaceTarget.TORCH_TENSOR: ["sigmoid"],
548561
NamespaceTarget.TORCH: ["sigmoid"],
562+
NamespaceTarget.ATEN: ["sigmoid"],
549563
}
550564

551565

@@ -561,6 +575,7 @@ class PTAddMetatype(PTOperatorMetatype):
561575
"__radd__",
562576
],
563577
NamespaceTarget.TORCH: ["add"],
578+
NamespaceTarget.ATEN: ["add_", "add"],
564579
}
565580
hw_config_names = [HWConfigOpName.ADD]
566581
num_expected_input_edges = 2
@@ -578,6 +593,7 @@ class PTSubMetatype(PTOperatorMetatype):
578593
"__rsub__",
579594
],
580595
NamespaceTarget.TORCH: ["sub"],
596+
NamespaceTarget.ATEN: ["sub", "sub_"],
581597
}
582598
hw_config_names = [HWConfigOpName.SUBTRACT]
583599
num_expected_input_edges = 2
@@ -589,6 +605,7 @@ class PTMulMetatype(PTOperatorMetatype):
589605
module_to_function_names = {
590606
NamespaceTarget.TORCH_TENSOR: ["mul", "mul_", "__mul__", "__imul__", "__rmul__"],
591607
NamespaceTarget.TORCH: ["mul"],
608+
NamespaceTarget.ATEN: ["mul", "mul_"],
592609
}
593610
hw_config_names = [HWConfigOpName.MULTIPLY]
594611
num_expected_input_edges = 2
@@ -609,6 +626,7 @@ class PTDivMetatype(PTOperatorMetatype):
609626
"__rtruediv__",
610627
],
611628
NamespaceTarget.TORCH: ["div"],
629+
NamespaceTarget.ATEN: ["div", "div_"],
612630
}
613631
hw_config_names = [HWConfigOpName.DIVIDE]
614632
num_expected_input_edges = 2
@@ -630,6 +648,7 @@ class PTExpMetatype(PTOperatorMetatype):
630648
module_to_function_names = {
631649
NamespaceTarget.TORCH_TENSOR: ["exp"],
632650
NamespaceTarget.TORCH: ["exp"],
651+
NamespaceTarget.ATEN: ["exp"],
633652
}
634653

635654

@@ -665,6 +684,7 @@ class PTMatMulMetatype(PTOperatorMetatype):
665684
module_to_function_names = {
666685
NamespaceTarget.TORCH_TENSOR: ["matmul", "__matmul__", "__rmatmul__"],
667686
NamespaceTarget.TORCH: ["matmul", "bmm", "mm"],
687+
NamespaceTarget.ATEN: ["matmul"],
668688
}
669689
hw_config_names = [HWConfigOpName.MATMUL]
670690
num_expected_input_edges = 2
@@ -687,7 +707,7 @@ class PTAddmmMetatype(PTOperatorMetatype):
687707
@PT_OPERATOR_METATYPES.register()
688708
class PTMeanMetatype(PTOperatorMetatype):
689709
name = "MeanOp"
690-
module_to_function_names = {NamespaceTarget.TORCH_TENSOR: ["mean"]}
710+
module_to_function_names = {NamespaceTarget.TORCH_TENSOR: ["mean"], NamespaceTarget.ATEN: ["mean"]}
691711
hw_config_names = [HWConfigOpName.REDUCEMEAN]
692712

693713

@@ -700,7 +720,11 @@ class PTRoundMetatype(PTOperatorMetatype):
700720
@PT_OPERATOR_METATYPES.register()
701721
class PTDropoutMetatype(PTOperatorMetatype):
702722
name = "DropoutOp"
703-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["dropout"], NamespaceTarget.TORCH: ["dropout_"]}
723+
module_to_function_names = {
724+
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["dropout"],
725+
NamespaceTarget.TORCH: ["dropout_"],
726+
NamespaceTarget.ATEN: ["dropout", "dropout_"],
727+
}
704728

705729

706730
@PT_OPERATOR_METATYPES.register()
@@ -714,7 +738,7 @@ class PTModuleBatchNormMetatype(PTModuleOperatorSubtype):
714738
name = "BatchNormOp"
715739
module_to_function_names = {
716740
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"],
717-
NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training", "cudnn_batch_norm"],
741+
NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training", "cudnn_batch_norm", "batch_norm"],
718742
}
719743

720744

@@ -723,7 +747,7 @@ class PTBatchNormMetatype(PTOperatorMetatype):
723747
name = "BatchNormOp"
724748
module_to_function_names = {
725749
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"],
726-
NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training", "cudnn_batch_norm"],
750+
NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training", "cudnn_batch_norm", "batch_norm"],
727751
}
728752
subtypes = [PTModuleBatchNormMetatype]
729753
weight_port_ids = [3]
@@ -733,7 +757,10 @@ class PTBatchNormMetatype(PTOperatorMetatype):
733757
@PT_OPERATOR_METATYPES.register()
734758
class PTAvgPool2dMetatype(PTOperatorMetatype):
735759
name = "AvgPool2DOp"
736-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["avg_pool2d", "adaptive_avg_pool2d"]}
760+
module_to_function_names = {
761+
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["avg_pool2d", "adaptive_avg_pool2d"],
762+
NamespaceTarget.ATEN: ["adaptive_avg_pool2d"],
763+
}
737764
hw_config_names = [HWConfigOpName.AVGPOOL]
738765

739766

@@ -770,7 +797,10 @@ class PTMaxPool1dMetatype(PTOperatorMetatype):
770797
@PT_OPERATOR_METATYPES.register()
771798
class PTMaxPool2dMetatype(PTOperatorMetatype):
772799
name = "MaxPool2DOp"
773-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["max_pool2d"]}
800+
module_to_function_names = {
801+
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["max_pool2d"],
802+
NamespaceTarget.ATEN: ["max_pool2d"],
803+
}
774804
hw_config_names = [HWConfigOpName.MAXPOOL]
775805

776806

@@ -802,20 +832,26 @@ class PTMaxUnpool3dMetatype(PTOperatorMetatype):
802832
@PT_OPERATOR_METATYPES.register()
803833
class PTPadMetatype(PTOperatorMetatype):
804834
name = "PadOp"
805-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["pad"]}
835+
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["pad"], NamespaceTarget.ATEN: ["pad"]}
806836

807837

808838
@PT_OPERATOR_METATYPES.register()
809839
class PTCatMetatype(PTOperatorMetatype):
810840
name = "CatOp"
811-
module_to_function_names = {NamespaceTarget.TORCH: ["cat", "stack", "concat"]}
841+
module_to_function_names = {
842+
NamespaceTarget.TORCH: ["cat", "stack", "concat"],
843+
NamespaceTarget.ATEN: ["cat", "concat"],
844+
}
812845
hw_config_names = [HWConfigOpName.CONCAT]
813846

814847

815848
@PT_OPERATOR_METATYPES.register()
816849
class PTRELUMetatype(PTOperatorMetatype):
817850
name = "ReluOp"
818-
module_to_function_names = {NamespaceTarget.TORCH: ["relu", "relu_"]}
851+
module_to_function_names = {
852+
NamespaceTarget.TORCH: ["relu", "relu_"],
853+
NamespaceTarget.ATEN: ["relu_", "relu"],
854+
}
819855

820856

821857
@PT_OPERATOR_METATYPES.register()
@@ -827,14 +863,14 @@ class PTRELU6Metatype(PTOperatorMetatype):
827863
@PT_OPERATOR_METATYPES.register()
828864
class PTMaxMetatype(PTOperatorMetatype):
829865
name = "MaxOp"
830-
module_to_function_names = {NamespaceTarget.TORCH: ["max"]}
866+
module_to_function_names = {NamespaceTarget.TORCH: ["max"], NamespaceTarget.ATEN: ["max"]}
831867
hw_config_names = [HWConfigOpName.MAXIMUM, HWConfigOpName.REDUCEMAX]
832868

833869

834870
@PT_OPERATOR_METATYPES.register()
835871
class PTMinMetatype(PTOperatorMetatype):
836872
name = "MinOp"
837-
module_to_function_names = {NamespaceTarget.TORCH: ["min"]}
873+
module_to_function_names = {NamespaceTarget.TORCH: ["min"], NamespaceTarget.ATEN: ["min"]}
838874
hw_config_names = [HWConfigOpName.MINIMUM]
839875

840876

@@ -844,6 +880,7 @@ class PTTransposeMetatype(PTOperatorMetatype):
844880
module_to_function_names = {
845881
NamespaceTarget.TORCH_TENSOR: ["transpose", "permute", "transpose_"],
846882
NamespaceTarget.TORCH: ["transpose"],
883+
NamespaceTarget.ATEN: ["transpose", "permute", "transpose_"],
847884
}
848885
hw_config_names = [HWConfigOpName.TRANSPOSE]
849886

@@ -854,14 +891,17 @@ class PTGatherMetatype(PTOperatorMetatype):
854891
module_to_function_names = {
855892
NamespaceTarget.TORCH_TENSOR: ["index_select", "__getitem__"],
856893
NamespaceTarget.TORCH: ["gather", "index_select", "select", "where"],
857-
NamespaceTarget.ATEN: ["slice"],
894+
NamespaceTarget.ATEN: ["slice", "select", "__getitem__"],
858895
}
859896

860897

861898
@PT_OPERATOR_METATYPES.register()
862899
class PTScatterMetatype(PTOperatorMetatype):
863900
name = "ScatterOp"
864-
module_to_function_names = {NamespaceTarget.TORCH_TENSOR: ["scatter", "masked_fill", "masked_fill_"]}
901+
module_to_function_names = {
902+
NamespaceTarget.TORCH_TENSOR: ["scatter", "masked_fill", "masked_fill_"],
903+
NamespaceTarget.ATEN: ["masked_fill"],
904+
}
865905

866906

867907
@PT_OPERATOR_METATYPES.register()
@@ -870,6 +910,7 @@ class PTReshapeMetatype(PTOperatorMetatype):
870910
module_to_function_names = {
871911
NamespaceTarget.TORCH_TENSOR: ["reshape", "view", "flatten", "unsqueeze"],
872912
NamespaceTarget.TORCH: ["flatten", "unflatten", "unsqueeze"],
913+
NamespaceTarget.ATEN: ["flatten", "reshape", "view", "unsqueeze", "unflatten"],
873914
}
874915
hw_config_names = [HWConfigOpName.RESHAPE, HWConfigOpName.UNSQUEEZE, HWConfigOpName.FLATTEN]
875916

@@ -880,6 +921,7 @@ class PTSqueezeMetatype(PTOperatorMetatype):
880921
module_to_function_names = {
881922
NamespaceTarget.TORCH_TENSOR: ["squeeze"],
882923
NamespaceTarget.TORCH: ["squeeze"],
924+
NamespaceTarget.ATEN: ["squeeze"],
883925
}
884926
hw_config_names = [HWConfigOpName.SQUEEZE]
885927

@@ -891,21 +933,21 @@ class PTSplitMetatype(PTOperatorMetatype):
891933
NamespaceTarget.TORCH_NN_FUNCTIONAL: [],
892934
NamespaceTarget.TORCH_TENSOR: ["split", "chunk", "unbind"],
893935
NamespaceTarget.TORCH: ["split", "chunk", "unbind"],
894-
NamespaceTarget.ATEN: ["split_with_sizes"],
936+
NamespaceTarget.ATEN: ["split_with_sizes", "split"],
895937
}
896938
hw_config_names = [HWConfigOpName.SPLIT, HWConfigOpName.CHUNK]
897939

898940

899941
@PT_OPERATOR_METATYPES.register()
900942
class PTExpandMetatype(PTOperatorMetatype):
901943
name = "ExpandOp"
902-
module_to_function_names = {NamespaceTarget.TORCH_TENSOR: ["expand"]}
944+
module_to_function_names = {NamespaceTarget.TORCH_TENSOR: ["expand"], NamespaceTarget.ATEN: ["expand"]}
903945

904946

905947
@PT_OPERATOR_METATYPES.register()
906948
class PTExpandAsMetatype(PTOperatorMetatype):
907949
name = "ExpandAsOp"
908-
module_to_function_names = {NamespaceTarget.TORCH_TENSOR: ["expand_as"]}
950+
module_to_function_names = {NamespaceTarget.TORCH_TENSOR: ["expand_as"], NamespaceTarget.ATEN: ["expand_as"]}
909951

910952

911953
@PT_OPERATOR_METATYPES.register(is_subtype=True)
@@ -953,7 +995,7 @@ class PTEmbeddingBagMetatype(PTOperatorMetatype):
953995
@PT_OPERATOR_METATYPES.register()
954996
class PTSoftmaxMetatype(PTOperatorMetatype):
955997
name = "SoftmaxOp"
956-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["softmax"]}
998+
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["softmax"], NamespaceTarget.ATEN: ["softmax"]}
957999

9581000

9591001
@PT_OPERATOR_METATYPES.register()
@@ -1111,6 +1153,7 @@ class PTScaledDotProductAttentionMetatype(PTOperatorMetatype):
11111153
name = "ScaledDotProductAttentionOp"
11121154
module_to_function_names = {
11131155
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["scaled_dot_product_attention"],
1156+
NamespaceTarget.ATEN: ["scaled_dot_product_attention"],
11141157
}
11151158
hw_config_names = [HWConfigOpName.SCALED_DOT_PRODUCT_ATTENTION]
11161159
target_input_ports = [0, 1]

0 commit comments

Comments
 (0)