@@ -54,8 +54,7 @@ void checkImplicitTensorToNum(at::Tensor t, bool toInt) {
54
54
" Cannot input a tensor of dimension other than 0 as a scalar argument" );
55
55
}
56
56
if (toInt &&
57
- !isIntegralType (
58
- autograd::as_variable_ref (t).data ().scalar_type ())) {
57
+ !isIntegralType (autograd::as_variable_ref (t).data ().scalar_type ())) {
59
58
std::stringstream ss;
60
59
ss << " Cannot input a tensor of type " << t.scalar_type ()
61
60
<< " as an integral argument" ;
@@ -1114,6 +1113,76 @@ int listRemove<Shared<TensorList>, at::Tensor>(Stack& stack) {
1114
1113
return 0 ;
1115
1114
}
1116
1115
1116
+ template <typename TList, typename TElement>
1117
+ int listIndex (Stack& stack) {
1118
+ TList list;
1119
+ TElement elem;
1120
+ pop (stack, list, elem);
1121
+
1122
+ auto & elements = list->elements ();
1123
+ auto pos = std::find (elements.begin (), elements.end (), elem);
1124
+
1125
+ if (pos != elements.end ()) {
1126
+ push (stack, static_cast <int64_t >(std::distance (elements.begin (), pos)));
1127
+ } else {
1128
+ AT_ERROR (" '" , elem, " ' is not in list" );
1129
+ }
1130
+
1131
+ return 0 ;
1132
+ }
1133
+
1134
+ template <>
1135
+ int listIndex<Shared<TensorList>, at::Tensor>(Stack& stack) {
1136
+ Shared<TensorList> list;
1137
+ at::Tensor elem;
1138
+ pop (stack, list, elem);
1139
+
1140
+ auto & elements = list->elements ();
1141
+ auto pos = std::find_if (
1142
+ elements.begin (), elements.end (), [elem](const at::Tensor& b) {
1143
+ const auto cmp_result = elem.eq (b);
1144
+ return cmp_result.is_nonzero ();
1145
+ });
1146
+
1147
+ if (pos != elements.end ()) {
1148
+ push (stack, static_cast <int64_t >(std::distance (elements.begin (), pos)));
1149
+ } else {
1150
+ AT_ERROR (" '" , elem, " ' is not in list" );
1151
+ }
1152
+
1153
+ return 0 ;
1154
+ }
1155
+
1156
+ template <typename TList, typename TElement>
1157
+ int listCount (Stack& stack) {
1158
+ TList list;
1159
+ TElement elem;
1160
+ pop (stack, list, elem);
1161
+
1162
+ auto & elements = list->elements ();
1163
+ const int64_t count = std::count (elements.begin (), elements.end (), elem);
1164
+ push (stack, count);
1165
+
1166
+ return 0 ;
1167
+ }
1168
+
1169
+ template <>
1170
+ int listCount<Shared<TensorList>, at::Tensor>(Stack& stack) {
1171
+ Shared<TensorList> list;
1172
+ at::Tensor elem;
1173
+ pop (stack, list, elem);
1174
+
1175
+ auto & elements = list->elements ();
1176
+ const int64_t count = std::count_if (
1177
+ elements.begin (), elements.end (), [elem](const at::Tensor& b) {
1178
+ const auto cmp_result = elem.eq (b);
1179
+ return cmp_result.is_nonzero ();
1180
+ });
1181
+ push (stack, count);
1182
+
1183
+ return 0 ;
1184
+ }
1185
+
1117
1186
template <typename TList>
1118
1187
Operation listExtend (const Node* node) {
1119
1188
return [](Stack& stack) {
@@ -1485,6 +1554,12 @@ RegisterOperators reg2({
1485
1554
Operator (
1486
1555
" aten::remove(Tensor[](a!) self, Tensor el) -> ()" ,
1487
1556
listRemove<Shared<TensorList>, at::Tensor>),
1557
+ Operator (
1558
+ " aten::index(Tensor[] self, Tensor el) -> int" ,
1559
+ listIndex<Shared<TensorList>, at::Tensor>),
1560
+ Operator (
1561
+ " aten::count(Tensor[] self, Tensor el) -> int" ,
1562
+ listCount<Shared<TensorList>, at::Tensor>),
1488
1563
1489
1564
// Mutable ops for lists containing immutable types.
1490
1565
#define CREATE_IMMUTABLE_LIST_OPS (decl_type, c_type ) \
@@ -1524,6 +1599,16 @@ RegisterOperators reg2({
1524
1599
" [](a!) self, \
1525
1600
" decl_type " el) -> ()" , \
1526
1601
listRemove<Shared<c_type>, c_type::ElemType>), \
1602
+ Operator ( \
1603
+ " aten::index(" decl_type \
1604
+ " [] self, \
1605
+ " decl_type " el) -> int" , \
1606
+ listIndex<Shared<c_type>, c_type::ElemType>), \
1607
+ Operator ( \
1608
+ " aten::count(" decl_type \
1609
+ " [] self, \
1610
+ " decl_type " el) -> int" , \
1611
+ listCount<Shared<c_type>, c_type::ElemType>), \
1527
1612
Operator ( \
1528
1613
" aten::pop(" decl_type \
1529
1614
" [](a!) self, int idx=-1) \
0 commit comments