@@ -193,6 +193,7 @@ class PTNoopMetatype(PTOperatorMetatype):
193
193
NamespaceTarget .TORCH_NN_FUNCTIONAL : [],
194
194
NamespaceTarget .TORCH_TENSOR : ["contiguous" , "clone" , "detach" , "detach_" , "to" ],
195
195
NamespaceTarget .TORCH : ["clone" , "detach" , "detach_" ],
196
+ NamespaceTarget .ATEN : ["contiguous" , "clone" , "to" ],
196
197
}
197
198
198
199
@@ -246,7 +247,7 @@ class PTConv1dMetatype(PTOperatorMetatype):
246
247
class PTModuleDepthwiseConv2dSubtype (PTModuleDepthwiseConvOperatorSubtype ):
247
248
name = "Conv2DOp"
248
249
hw_config_names = [HWConfigOpName .DEPTHWISECONVOLUTION ]
249
- module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["conv2d" ]}
250
+ module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["conv2d" ], NamespaceTarget . ATEN : [ "conv2d" ] }
250
251
output_channel_axis = 1
251
252
num_expected_input_edges = 2
252
253
weight_port_ids = [1 ]
@@ -257,7 +258,7 @@ class PTModuleDepthwiseConv2dSubtype(PTModuleDepthwiseConvOperatorSubtype):
257
258
class PTModuleConv2dMetatype (PTModuleOperatorSubtype ):
258
259
name = "Conv2DOp"
259
260
hw_config_names = [HWConfigOpName .CONVOLUTION ]
260
- module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["conv2d" ]}
261
+ module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["conv2d" ], NamespaceTarget . ATEN : [ "conv2d" ] }
261
262
subtypes = [PTModuleDepthwiseConv2dSubtype ]
262
263
output_channel_axis = 1
263
264
num_expected_input_edges = 2
@@ -280,7 +281,7 @@ class PTDepthwiseConv2dSubtype(PTDepthwiseConvOperatorSubtype):
280
281
class PTConv2dMetatype (PTOperatorMetatype ):
281
282
name = "Conv2DOp"
282
283
hw_config_names = [HWConfigOpName .CONVOLUTION ]
283
- module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["conv2d" ]}
284
+ module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["conv2d" ], NamespaceTarget . ATEN : [ "conv2d" ] }
284
285
subtypes = [PTModuleConv2dMetatype , PTDepthwiseConv2dSubtype ]
285
286
output_channel_axis = 1
286
287
num_expected_input_edges = 2
@@ -372,7 +373,10 @@ class PTModuleConvTranspose2dMetatype(PTModuleOperatorSubtype):
372
373
class PTConvTranspose2dMetatype (PTOperatorMetatype ):
373
374
name = "ConvTranspose2DOp"
374
375
hw_config_names = [HWConfigOpName .CONVOLUTION ]
375
- module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["conv_transpose2d" ]}
376
+ module_to_function_names = {
377
+ NamespaceTarget .TORCH_NN_FUNCTIONAL : ["conv_transpose2d" ],
378
+ NamespaceTarget .ATEN : ["conv_transpose2d" ],
379
+ }
376
380
subtypes = [PTModuleConvTranspose2dMetatype ]
377
381
output_channel_axis = 1
378
382
num_expected_input_edges = 2
@@ -433,7 +437,7 @@ class PTModuleLinearMetatype(PTModuleOperatorSubtype):
433
437
@PT_OPERATOR_METATYPES .register ()
434
438
class PTLinearMetatype (PTOperatorMetatype ):
435
439
name = "LinearOp"
436
- module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["linear" ]}
440
+ module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["linear" ], NamespaceTarget . ATEN : [ "linear" ] }
437
441
hw_config_names = [HWConfigOpName .MATMUL ]
438
442
subtypes = [PTModuleLinearMetatype ]
439
443
output_channel_axis = - 1
@@ -451,14 +455,20 @@ class PTHardTanhMetatype(PTOperatorMetatype):
451
455
@PT_OPERATOR_METATYPES .register ()
452
456
class PTHardSwishMetatype (PTOperatorMetatype ):
453
457
name = "HardSwishOp"
454
- module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["hardswish" , "hardswish_" ]}
458
+ module_to_function_names = {
459
+ NamespaceTarget .TORCH_NN_FUNCTIONAL : ["hardswish" , "hardswish_" ],
460
+ NamespaceTarget .ATEN : ["hardswish" , "hardswish_" ],
461
+ }
455
462
num_expected_input_edges = 1
456
463
457
464
458
465
@PT_OPERATOR_METATYPES .register ()
459
466
class PTHardSigmoidMetatype (PTOperatorMetatype ):
460
467
name = "HardSigmoidOp"
461
- module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["hardsigmoid" ]}
468
+ module_to_function_names = {
469
+ NamespaceTarget .TORCH_NN_FUNCTIONAL : ["hardsigmoid" ],
470
+ NamespaceTarget .ATEN : ["hardsigmoid" ],
471
+ }
462
472
num_expected_input_edges = 1
463
473
464
474
@@ -502,7 +512,10 @@ class PTModuleLayerNormMetatype(PTModuleOperatorSubtype):
502
512
@PT_OPERATOR_METATYPES .register ()
503
513
class PTLayerNormMetatype (PTOperatorMetatype ):
504
514
name = "LayerNormOp"
505
- module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["layer_norm" ]}
515
+ module_to_function_names = {
516
+ NamespaceTarget .TORCH_NN_FUNCTIONAL : ["layer_norm" ],
517
+ NamespaceTarget .ATEN : ["layer_norm" ],
518
+ }
506
519
hw_config_names = [HWConfigOpName .MVN ]
507
520
subtypes = [PTModuleLayerNormMetatype ]
508
521
num_expected_input_edges = 1
@@ -530,7 +543,7 @@ class PTGroupNormMetatype(PTOperatorMetatype):
530
543
class PTGELUMetatype (PTOperatorMetatype ):
531
544
name = "GeluOp"
532
545
hw_config_names = [HWConfigOpName .GELU ]
533
- module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["gelu" ]}
546
+ module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["gelu" ], NamespaceTarget . ATEN : [ "gelu" ] }
534
547
535
548
536
549
@PT_OPERATOR_METATYPES .register ()
@@ -546,6 +559,7 @@ class PTSigmoidMetatype(PTOperatorMetatype):
546
559
NamespaceTarget .TORCH_NN_FUNCTIONAL : ["sigmoid" ],
547
560
NamespaceTarget .TORCH_TENSOR : ["sigmoid" ],
548
561
NamespaceTarget .TORCH : ["sigmoid" ],
562
+ NamespaceTarget .ATEN : ["sigmoid" ],
549
563
}
550
564
551
565
@@ -561,6 +575,7 @@ class PTAddMetatype(PTOperatorMetatype):
561
575
"__radd__" ,
562
576
],
563
577
NamespaceTarget .TORCH : ["add" ],
578
+ NamespaceTarget .ATEN : ["add_" , "add" ],
564
579
}
565
580
hw_config_names = [HWConfigOpName .ADD ]
566
581
num_expected_input_edges = 2
@@ -578,6 +593,7 @@ class PTSubMetatype(PTOperatorMetatype):
578
593
"__rsub__" ,
579
594
],
580
595
NamespaceTarget .TORCH : ["sub" ],
596
+ NamespaceTarget .ATEN : ["sub" , "sub_" ],
581
597
}
582
598
hw_config_names = [HWConfigOpName .SUBTRACT ]
583
599
num_expected_input_edges = 2
@@ -589,6 +605,7 @@ class PTMulMetatype(PTOperatorMetatype):
589
605
module_to_function_names = {
590
606
NamespaceTarget .TORCH_TENSOR : ["mul" , "mul_" , "__mul__" , "__imul__" , "__rmul__" ],
591
607
NamespaceTarget .TORCH : ["mul" ],
608
+ NamespaceTarget .ATEN : ["mul" , "mul_" ],
592
609
}
593
610
hw_config_names = [HWConfigOpName .MULTIPLY ]
594
611
num_expected_input_edges = 2
@@ -609,6 +626,7 @@ class PTDivMetatype(PTOperatorMetatype):
609
626
"__rtruediv__" ,
610
627
],
611
628
NamespaceTarget .TORCH : ["div" ],
629
+ NamespaceTarget .ATEN : ["div" , "div_" ],
612
630
}
613
631
hw_config_names = [HWConfigOpName .DIVIDE ]
614
632
num_expected_input_edges = 2
@@ -630,6 +648,7 @@ class PTExpMetatype(PTOperatorMetatype):
630
648
module_to_function_names = {
631
649
NamespaceTarget .TORCH_TENSOR : ["exp" ],
632
650
NamespaceTarget .TORCH : ["exp" ],
651
+ NamespaceTarget .ATEN : ["exp" ],
633
652
}
634
653
635
654
@@ -665,6 +684,7 @@ class PTMatMulMetatype(PTOperatorMetatype):
665
684
module_to_function_names = {
666
685
NamespaceTarget .TORCH_TENSOR : ["matmul" , "__matmul__" , "__rmatmul__" ],
667
686
NamespaceTarget .TORCH : ["matmul" , "bmm" , "mm" ],
687
+ NamespaceTarget .ATEN : ["matmul" ],
668
688
}
669
689
hw_config_names = [HWConfigOpName .MATMUL ]
670
690
num_expected_input_edges = 2
@@ -687,7 +707,7 @@ class PTAddmmMetatype(PTOperatorMetatype):
687
707
@PT_OPERATOR_METATYPES .register ()
688
708
class PTMeanMetatype (PTOperatorMetatype ):
689
709
name = "MeanOp"
690
- module_to_function_names = {NamespaceTarget .TORCH_TENSOR : ["mean" ]}
710
+ module_to_function_names = {NamespaceTarget .TORCH_TENSOR : ["mean" ], NamespaceTarget . ATEN : [ "mean" ] }
691
711
hw_config_names = [HWConfigOpName .REDUCEMEAN ]
692
712
693
713
@@ -700,7 +720,11 @@ class PTRoundMetatype(PTOperatorMetatype):
700
720
@PT_OPERATOR_METATYPES .register ()
701
721
class PTDropoutMetatype (PTOperatorMetatype ):
702
722
name = "DropoutOp"
703
- module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["dropout" ], NamespaceTarget .TORCH : ["dropout_" ]}
723
+ module_to_function_names = {
724
+ NamespaceTarget .TORCH_NN_FUNCTIONAL : ["dropout" ],
725
+ NamespaceTarget .TORCH : ["dropout_" ],
726
+ NamespaceTarget .ATEN : ["dropout" , "dropout_" ],
727
+ }
704
728
705
729
706
730
@PT_OPERATOR_METATYPES .register ()
@@ -714,7 +738,7 @@ class PTModuleBatchNormMetatype(PTModuleOperatorSubtype):
714
738
name = "BatchNormOp"
715
739
module_to_function_names = {
716
740
NamespaceTarget .TORCH_NN_FUNCTIONAL : ["batch_norm" ],
717
- NamespaceTarget .ATEN : ["_native_batch_norm_legit_no_training" , "cudnn_batch_norm" ],
741
+ NamespaceTarget .ATEN : ["_native_batch_norm_legit_no_training" , "cudnn_batch_norm" , "batch_norm" ],
718
742
}
719
743
720
744
@@ -723,7 +747,7 @@ class PTBatchNormMetatype(PTOperatorMetatype):
723
747
name = "BatchNormOp"
724
748
module_to_function_names = {
725
749
NamespaceTarget .TORCH_NN_FUNCTIONAL : ["batch_norm" ],
726
- NamespaceTarget .ATEN : ["_native_batch_norm_legit_no_training" , "cudnn_batch_norm" ],
750
+ NamespaceTarget .ATEN : ["_native_batch_norm_legit_no_training" , "cudnn_batch_norm" , "batch_norm" ],
727
751
}
728
752
subtypes = [PTModuleBatchNormMetatype ]
729
753
weight_port_ids = [3 ]
@@ -733,7 +757,10 @@ class PTBatchNormMetatype(PTOperatorMetatype):
733
757
@PT_OPERATOR_METATYPES .register ()
734
758
class PTAvgPool2dMetatype (PTOperatorMetatype ):
735
759
name = "AvgPool2DOp"
736
- module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["avg_pool2d" , "adaptive_avg_pool2d" ]}
760
+ module_to_function_names = {
761
+ NamespaceTarget .TORCH_NN_FUNCTIONAL : ["avg_pool2d" , "adaptive_avg_pool2d" ],
762
+ NamespaceTarget .ATEN : ["adaptive_avg_pool2d" ],
763
+ }
737
764
hw_config_names = [HWConfigOpName .AVGPOOL ]
738
765
739
766
@@ -770,7 +797,10 @@ class PTMaxPool1dMetatype(PTOperatorMetatype):
770
797
@PT_OPERATOR_METATYPES .register ()
771
798
class PTMaxPool2dMetatype (PTOperatorMetatype ):
772
799
name = "MaxPool2DOp"
773
- module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["max_pool2d" ]}
800
+ module_to_function_names = {
801
+ NamespaceTarget .TORCH_NN_FUNCTIONAL : ["max_pool2d" ],
802
+ NamespaceTarget .ATEN : ["max_pool2d" ],
803
+ }
774
804
hw_config_names = [HWConfigOpName .MAXPOOL ]
775
805
776
806
@@ -802,20 +832,26 @@ class PTMaxUnpool3dMetatype(PTOperatorMetatype):
802
832
@PT_OPERATOR_METATYPES .register ()
803
833
class PTPadMetatype (PTOperatorMetatype ):
804
834
name = "PadOp"
805
- module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["pad" ]}
835
+ module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["pad" ], NamespaceTarget . ATEN : [ "pad" ] }
806
836
807
837
808
838
@PT_OPERATOR_METATYPES .register ()
809
839
class PTCatMetatype (PTOperatorMetatype ):
810
840
name = "CatOp"
811
- module_to_function_names = {NamespaceTarget .TORCH : ["cat" , "stack" , "concat" ]}
841
+ module_to_function_names = {
842
+ NamespaceTarget .TORCH : ["cat" , "stack" , "concat" ],
843
+ NamespaceTarget .ATEN : ["cat" , "concat" ],
844
+ }
812
845
hw_config_names = [HWConfigOpName .CONCAT ]
813
846
814
847
815
848
@PT_OPERATOR_METATYPES .register ()
816
849
class PTRELUMetatype (PTOperatorMetatype ):
817
850
name = "ReluOp"
818
- module_to_function_names = {NamespaceTarget .TORCH : ["relu" , "relu_" ]}
851
+ module_to_function_names = {
852
+ NamespaceTarget .TORCH : ["relu" , "relu_" ],
853
+ NamespaceTarget .ATEN : ["relu_" , "relu" ],
854
+ }
819
855
820
856
821
857
@PT_OPERATOR_METATYPES .register ()
@@ -827,14 +863,14 @@ class PTRELU6Metatype(PTOperatorMetatype):
827
863
@PT_OPERATOR_METATYPES .register ()
828
864
class PTMaxMetatype (PTOperatorMetatype ):
829
865
name = "MaxOp"
830
- module_to_function_names = {NamespaceTarget .TORCH : ["max" ]}
866
+ module_to_function_names = {NamespaceTarget .TORCH : ["max" ], NamespaceTarget . ATEN : [ "max" ] }
831
867
hw_config_names = [HWConfigOpName .MAXIMUM , HWConfigOpName .REDUCEMAX ]
832
868
833
869
834
870
@PT_OPERATOR_METATYPES .register ()
835
871
class PTMinMetatype (PTOperatorMetatype ):
836
872
name = "MinOp"
837
- module_to_function_names = {NamespaceTarget .TORCH : ["min" ]}
873
+ module_to_function_names = {NamespaceTarget .TORCH : ["min" ], NamespaceTarget . ATEN : [ "min" ] }
838
874
hw_config_names = [HWConfigOpName .MINIMUM ]
839
875
840
876
@@ -844,6 +880,7 @@ class PTTransposeMetatype(PTOperatorMetatype):
844
880
module_to_function_names = {
845
881
NamespaceTarget .TORCH_TENSOR : ["transpose" , "permute" , "transpose_" ],
846
882
NamespaceTarget .TORCH : ["transpose" ],
883
+ NamespaceTarget .ATEN : ["transpose" , "permute" , "transpose_" ],
847
884
}
848
885
hw_config_names = [HWConfigOpName .TRANSPOSE ]
849
886
@@ -854,14 +891,17 @@ class PTGatherMetatype(PTOperatorMetatype):
854
891
module_to_function_names = {
855
892
NamespaceTarget .TORCH_TENSOR : ["index_select" , "__getitem__" ],
856
893
NamespaceTarget .TORCH : ["gather" , "index_select" , "select" , "where" ],
857
- NamespaceTarget .ATEN : ["slice" ],
894
+ NamespaceTarget .ATEN : ["slice" , "select" , "__getitem__" ],
858
895
}
859
896
860
897
861
898
@PT_OPERATOR_METATYPES .register ()
862
899
class PTScatterMetatype (PTOperatorMetatype ):
863
900
name = "ScatterOp"
864
- module_to_function_names = {NamespaceTarget .TORCH_TENSOR : ["scatter" , "masked_fill" , "masked_fill_" ]}
901
+ module_to_function_names = {
902
+ NamespaceTarget .TORCH_TENSOR : ["scatter" , "masked_fill" , "masked_fill_" ],
903
+ NamespaceTarget .ATEN : ["masked_fill" ],
904
+ }
865
905
866
906
867
907
@PT_OPERATOR_METATYPES .register ()
@@ -870,6 +910,7 @@ class PTReshapeMetatype(PTOperatorMetatype):
870
910
module_to_function_names = {
871
911
NamespaceTarget .TORCH_TENSOR : ["reshape" , "view" , "flatten" , "unsqueeze" ],
872
912
NamespaceTarget .TORCH : ["flatten" , "unflatten" , "unsqueeze" ],
913
+ NamespaceTarget .ATEN : ["flatten" , "reshape" , "view" , "unsqueeze" , "unflatten" ],
873
914
}
874
915
hw_config_names = [HWConfigOpName .RESHAPE , HWConfigOpName .UNSQUEEZE , HWConfigOpName .FLATTEN ]
875
916
@@ -880,6 +921,7 @@ class PTSqueezeMetatype(PTOperatorMetatype):
880
921
module_to_function_names = {
881
922
NamespaceTarget .TORCH_TENSOR : ["squeeze" ],
882
923
NamespaceTarget .TORCH : ["squeeze" ],
924
+ NamespaceTarget .ATEN : ["squeeze" ],
883
925
}
884
926
hw_config_names = [HWConfigOpName .SQUEEZE ]
885
927
@@ -891,21 +933,21 @@ class PTSplitMetatype(PTOperatorMetatype):
891
933
NamespaceTarget .TORCH_NN_FUNCTIONAL : [],
892
934
NamespaceTarget .TORCH_TENSOR : ["split" , "chunk" , "unbind" ],
893
935
NamespaceTarget .TORCH : ["split" , "chunk" , "unbind" ],
894
- NamespaceTarget .ATEN : ["split_with_sizes" ],
936
+ NamespaceTarget .ATEN : ["split_with_sizes" , "split" ],
895
937
}
896
938
hw_config_names = [HWConfigOpName .SPLIT , HWConfigOpName .CHUNK ]
897
939
898
940
899
941
@PT_OPERATOR_METATYPES .register ()
900
942
class PTExpandMetatype (PTOperatorMetatype ):
901
943
name = "ExpandOp"
902
- module_to_function_names = {NamespaceTarget .TORCH_TENSOR : ["expand" ]}
944
+ module_to_function_names = {NamespaceTarget .TORCH_TENSOR : ["expand" ], NamespaceTarget . ATEN : [ "expand" ] }
903
945
904
946
905
947
@PT_OPERATOR_METATYPES .register ()
906
948
class PTExpandAsMetatype (PTOperatorMetatype ):
907
949
name = "ExpandAsOp"
908
- module_to_function_names = {NamespaceTarget .TORCH_TENSOR : ["expand_as" ]}
950
+ module_to_function_names = {NamespaceTarget .TORCH_TENSOR : ["expand_as" ], NamespaceTarget . ATEN : [ "expand_as" ] }
909
951
910
952
911
953
@PT_OPERATOR_METATYPES .register (is_subtype = True )
@@ -953,7 +995,7 @@ class PTEmbeddingBagMetatype(PTOperatorMetatype):
953
995
@PT_OPERATOR_METATYPES .register ()
954
996
class PTSoftmaxMetatype (PTOperatorMetatype ):
955
997
name = "SoftmaxOp"
956
- module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["softmax" ]}
998
+ module_to_function_names = {NamespaceTarget .TORCH_NN_FUNCTIONAL : ["softmax" ], NamespaceTarget . ATEN : [ "softmax" ] }
957
999
958
1000
959
1001
@PT_OPERATOR_METATYPES .register ()
@@ -1111,6 +1153,7 @@ class PTScaledDotProductAttentionMetatype(PTOperatorMetatype):
1111
1153
name = "ScaledDotProductAttentionOp"
1112
1154
module_to_function_names = {
1113
1155
NamespaceTarget .TORCH_NN_FUNCTIONAL : ["scaled_dot_product_attention" ],
1156
+ NamespaceTarget .ATEN : ["scaled_dot_product_attention" ],
1114
1157
}
1115
1158
hw_config_names = [HWConfigOpName .SCALED_DOT_PRODUCT_ATTENTION ]
1116
1159
target_input_ports = [0 , 1 ]
0 commit comments