Skip to content

Commit 9a8a268

Browse files
TheCodezfacebook-github-bot
authored andcommitted
add index and count to list (pytorch#17446)
Summary: see pytorch#16662 Pull Request resolved: pytorch#17446 Differential Revision: D14461293 Pulled By: Krovatkin fbshipit-source-id: 03572467cdf85efc909c1864c0558a93085c8ff3
1 parent 001cffe commit 9a8a268

File tree

2 files changed

+157
-2
lines changed

2 files changed

+157
-2
lines changed

test/test_jit.py

+70
Original file line numberDiff line numberDiff line change
@@ -4315,6 +4315,76 @@ def test_list_remove():
43154315
return a == [1, 2, 4]
43164316
self.checkScript(test_list_remove, ())
43174317

4318+
def test_list_index_not_existing(self):
4319+
@torch.jit.script
4320+
def list_index_not_existing():
4321+
a = [4, 1, 3, 2]
4322+
i = a.index(5)
4323+
4324+
return i
4325+
4326+
with self.assertRaisesRegex(RuntimeError, "'5' is not in list"):
4327+
list_index_not_existing()
4328+
4329+
def test_list_index(self):
4330+
def list_index():
4331+
a = [4, 1, 3, 2]
4332+
i = a.index(3)
4333+
4334+
return i == 2
4335+
self.checkScript(list_index, ())
4336+
4337+
def test_tensor_list_index(self):
4338+
def tensor_list_index():
4339+
a = [torch.tensor(4), torch.tensor(1), torch.tensor(3), torch.tensor(2)]
4340+
i = a.index(torch.tensor(3))
4341+
4342+
return i == 2
4343+
self.checkScript(tensor_list_index, ())
4344+
4345+
def test_tensor_list_index_not_existing(self):
4346+
@torch.jit.script
4347+
def tensor_list_index_not_existing():
4348+
a = [torch.tensor(4), torch.tensor(1), torch.tensor(3), torch.tensor(2)]
4349+
i = a.index(torch.tensor(5))
4350+
4351+
return i
4352+
4353+
with self.assertRaisesRegex(RuntimeError, "is not in list"):
4354+
tensor_list_index_not_existing()
4355+
4356+
def test_list_count(self):
4357+
def list_count():
4358+
a = [4, 1, 4, 2, 4]
4359+
i = a.count(4)
4360+
4361+
return i == 3
4362+
self.checkScript(list_count, ())
4363+
4364+
def test_list_count_not_existing(self):
4365+
def list_count_not_existing():
4366+
a = [4, 1, 4, 2, 4]
4367+
i = a.count(5)
4368+
4369+
return i == 0
4370+
self.checkScript(list_count_not_existing, ())
4371+
4372+
def test_tensor_list_count(self):
4373+
def tensor_list_count():
4374+
a = [torch.tensor(4), torch.tensor(1), torch.tensor(4), torch.tensor(4)]
4375+
i = a.count(torch.tensor(4))
4376+
4377+
return i == 3
4378+
self.checkScript(tensor_list_count, ())
4379+
4380+
def test_tensor_list_count_not_existing(self):
4381+
def tensor_list_count_not_existing():
4382+
a = [torch.tensor(4), torch.tensor(1), torch.tensor(4), torch.tensor(4)]
4383+
i = a.count(torch.tensor(5))
4384+
4385+
return i == 0
4386+
self.checkScript(tensor_list_count_not_existing, ())
4387+
43184388
def test_mutable_list_remove_tensor(self):
43194389
def test_list_remove_tensor():
43204390
a = [torch.ones(1), torch.zeros(1), torch.ones(2)]

torch/csrc/jit/register_prim_ops.cpp

+87-2
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ void checkImplicitTensorToNum(at::Tensor t, bool toInt) {
5454
"Cannot input a tensor of dimension other than 0 as a scalar argument");
5555
}
5656
if (toInt &&
57-
!isIntegralType(
58-
autograd::as_variable_ref(t).data().scalar_type())) {
57+
!isIntegralType(autograd::as_variable_ref(t).data().scalar_type())) {
5958
std::stringstream ss;
6059
ss << "Cannot input a tensor of type " << t.scalar_type()
6160
<< " as an integral argument";
@@ -1114,6 +1113,76 @@ int listRemove<Shared<TensorList>, at::Tensor>(Stack& stack) {
11141113
return 0;
11151114
}
11161115

1116+
template <typename TList, typename TElement>
1117+
int listIndex(Stack& stack) {
1118+
TList list;
1119+
TElement elem;
1120+
pop(stack, list, elem);
1121+
1122+
auto& elements = list->elements();
1123+
auto pos = std::find(elements.begin(), elements.end(), elem);
1124+
1125+
if (pos != elements.end()) {
1126+
push(stack, static_cast<int64_t>(std::distance(elements.begin(), pos)));
1127+
} else {
1128+
AT_ERROR("'", elem, "' is not in list");
1129+
}
1130+
1131+
return 0;
1132+
}
1133+
1134+
template <>
1135+
int listIndex<Shared<TensorList>, at::Tensor>(Stack& stack) {
1136+
Shared<TensorList> list;
1137+
at::Tensor elem;
1138+
pop(stack, list, elem);
1139+
1140+
auto& elements = list->elements();
1141+
auto pos = std::find_if(
1142+
elements.begin(), elements.end(), [elem](const at::Tensor& b) {
1143+
const auto cmp_result = elem.eq(b);
1144+
return cmp_result.is_nonzero();
1145+
});
1146+
1147+
if (pos != elements.end()) {
1148+
push(stack, static_cast<int64_t>(std::distance(elements.begin(), pos)));
1149+
} else {
1150+
AT_ERROR("'", elem, "' is not in list");
1151+
}
1152+
1153+
return 0;
1154+
}
1155+
1156+
template <typename TList, typename TElement>
1157+
int listCount(Stack& stack) {
1158+
TList list;
1159+
TElement elem;
1160+
pop(stack, list, elem);
1161+
1162+
auto& elements = list->elements();
1163+
const int64_t count = std::count(elements.begin(), elements.end(), elem);
1164+
push(stack, count);
1165+
1166+
return 0;
1167+
}
1168+
1169+
template <>
1170+
int listCount<Shared<TensorList>, at::Tensor>(Stack& stack) {
1171+
Shared<TensorList> list;
1172+
at::Tensor elem;
1173+
pop(stack, list, elem);
1174+
1175+
auto& elements = list->elements();
1176+
const int64_t count = std::count_if(
1177+
elements.begin(), elements.end(), [elem](const at::Tensor& b) {
1178+
const auto cmp_result = elem.eq(b);
1179+
return cmp_result.is_nonzero();
1180+
});
1181+
push(stack, count);
1182+
1183+
return 0;
1184+
}
1185+
11171186
template <typename TList>
11181187
Operation listExtend(const Node* node) {
11191188
return [](Stack& stack) {
@@ -1485,6 +1554,12 @@ RegisterOperators reg2({
14851554
Operator(
14861555
"aten::remove(Tensor[](a!) self, Tensor el) -> ()",
14871556
listRemove<Shared<TensorList>, at::Tensor>),
1557+
Operator(
1558+
"aten::index(Tensor[] self, Tensor el) -> int",
1559+
listIndex<Shared<TensorList>, at::Tensor>),
1560+
Operator(
1561+
"aten::count(Tensor[] self, Tensor el) -> int",
1562+
listCount<Shared<TensorList>, at::Tensor>),
14881563

14891564
// Mutable ops for lists containing immutable types.
14901565
#define CREATE_IMMUTABLE_LIST_OPS(decl_type, c_type) \
@@ -1524,6 +1599,16 @@ RegisterOperators reg2({
15241599
"[](a!) self, \
15251600
" decl_type " el) -> ()", \
15261601
listRemove<Shared<c_type>, c_type::ElemType>), \
1602+
Operator( \
1603+
"aten::index(" decl_type \
1604+
"[] self, \
1605+
" decl_type " el) -> int", \
1606+
listIndex<Shared<c_type>, c_type::ElemType>), \
1607+
Operator( \
1608+
"aten::count(" decl_type \
1609+
"[] self, \
1610+
" decl_type " el) -> int", \
1611+
listCount<Shared<c_type>, c_type::ElemType>), \
15271612
Operator( \
15281613
"aten::pop(" decl_type \
15291614
"[](a!) self, int idx=-1) \

0 commit comments

Comments
 (0)