Skip to content

Commit 2e52d2e

Browse files
committed
Added support for Mapping, Set and Sequence derived from collections.abc.
1 parent 2e382c0 commit 2e52d2e

File tree

3 files changed

+145
-2
lines changed

3 files changed

+145
-2
lines changed

include/pybind11/stl.h

+16-2
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,14 @@ struct set_caster {
173173
public:
174174
bool load(handle src, bool convert) {
175175
if (!PyObjectTypeIsConvertibleToStdSet(src.ptr())) {
176-
return false;
176+
if (!convert) {
177+
return false;
178+
}
179+
if (!(isinstance(src, module_::import("collections.abc").attr("Set"))
180+
&& hasattr(src, "__contains__") && hasattr(src, "__iter__")
181+
&& hasattr(src, "__len__"))) {
182+
return false;
183+
}
177184
}
178185
if (isinstance<anyset>(src)) {
179186
value.clear();
@@ -237,7 +244,14 @@ struct map_caster {
237244
public:
238245
bool load(handle src, bool convert) {
239246
if (!PyObjectTypeIsConvertibleToStdMap(src.ptr())) {
240-
return false;
247+
if (!convert) {
248+
return false;
249+
}
250+
if (!(isinstance(src, module_::import("collections.abc").attr("Mapping"))
251+
&& hasattr(src, "__getitem__") && hasattr(src, "__iter__")
252+
&& hasattr(src, "__len__"))) {
253+
return false;
254+
}
241255
}
242256
if (isinstance<dict>(src)) {
243257
return convert_elements(reinterpret_borrow<dict>(src), convert);

tests/test_stl.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -648,4 +648,7 @@ TEST_SUBMODULE(stl, m) {
648648
}
649649
return zum;
650650
});
651+
m.def("roundtrip_std_vector_int", [](const std::vector<int> &v) { return v; });
652+
m.def("roundtrip_std_map_str_int", [](const std::map<std::string, int> &m) { return m; });
653+
m.def("roundtrip_std_set_int", [](const std::set<int> &s) { return s; });
651654
}

tests/test_stl.py

+126
Original file line numberDiff line numberDiff line change
@@ -574,3 +574,129 @@ def gen_invalid():
574574
with pytest.raises(expected_exception):
575575
m.pass_std_map_int(FakePyMappingGenObj(gen_obj))
576576
assert not tuple(gen_obj)
577+
578+
579+
def test_sequence_caster_protocol(doc):
580+
from collections.abc import Sequence
581+
582+
class SequenceLike(Sequence):
583+
def __init__(self, *args):
584+
self.data = tuple(args)
585+
586+
def __len__(self):
587+
return len(self.data)
588+
589+
def __getitem__(self, index):
590+
return self.data[index]
591+
592+
class FakeSequenceLike:
593+
def __init__(self, *args):
594+
self.data = tuple(args)
595+
596+
def __len__(self):
597+
return len(self.data)
598+
599+
def __getitem__(self, index):
600+
return self.data[index]
601+
602+
assert (
603+
doc(m.roundtrip_std_vector_int)
604+
== "roundtrip_std_vector_int(arg0: collections.abc.Sequence[int]) -> list[int]"
605+
)
606+
assert m.roundtrip_std_vector_int([1, 2, 3]) == [1, 2, 3]
607+
assert m.roundtrip_std_vector_int((1, 2, 3)) == [1, 2, 3]
608+
assert m.roundtrip_std_vector_int(SequenceLike(1, 2, 3)) == [1, 2, 3]
609+
assert m.roundtrip_std_vector_int(FakeSequenceLike(1, 2, 3)) == [1, 2, 3]
610+
assert m.roundtrip_std_vector_int([]) == []
611+
assert m.roundtrip_std_vector_int(()) == []
612+
assert m.roundtrip_std_vector_int(FakeSequenceLike()) == []
613+
614+
615+
def test_mapping_caster_protocol(doc):
616+
from collections.abc import Mapping
617+
618+
class MappingLike(Mapping):
619+
def __init__(self, **kwargs):
620+
self.data = dict(kwargs)
621+
622+
def __len__(self):
623+
return len(self.data)
624+
625+
def __getitem__(self, key):
626+
return self.data[key]
627+
628+
def __iter__(self):
629+
yield from self.data
630+
631+
class FakeMappingLike:
632+
def __init__(self, **kwargs):
633+
self.data = dict(kwargs)
634+
635+
def __len__(self):
636+
return len(self.data)
637+
638+
def __getitem__(self, key):
639+
return self.data[key]
640+
641+
def __iter__(self):
642+
yield from self.data
643+
644+
assert (
645+
doc(m.roundtrip_std_map_str_int)
646+
== "roundtrip_std_map_str_int(arg0: collections.abc.Mapping[str, int]) -> dict[str, int]"
647+
)
648+
assert m.roundtrip_std_map_str_int({"a": 1, "b": 2, "c": 3}) == {
649+
"a": 1,
650+
"b": 2,
651+
"c": 3,
652+
}
653+
assert m.roundtrip_std_map_str_int(MappingLike(a=1, b=2, c=3)) == {
654+
"a": 1,
655+
"b": 2,
656+
"c": 3,
657+
}
658+
assert m.roundtrip_std_map_str_int({}) == {}
659+
assert m.roundtrip_std_map_str_int(MappingLike()) == {}
660+
with pytest.raises(TypeError):
661+
m.roundtrip_std_map_str_int(FakeMappingLike(a=1, b=2, c=3))
662+
663+
664+
def test_set_caster_protocol(doc):
665+
from collections.abc import Set
666+
667+
class SetLike(Set):
668+
def __init__(self, *args):
669+
self.data = set(args)
670+
671+
def __len__(self):
672+
return len(self.data)
673+
674+
def __contains__(self, item):
675+
return item in self.data
676+
677+
def __iter__(self):
678+
yield from self.data
679+
680+
class FakeSetLike:
681+
def __init__(self, *args):
682+
self.data = set(args)
683+
684+
def __len__(self):
685+
return len(self.data)
686+
687+
def __contains__(self, item):
688+
return item in self.data
689+
690+
def __iter__(self):
691+
yield from self.data
692+
693+
assert (
694+
doc(m.roundtrip_std_set_int)
695+
== "roundtrip_std_set_int(arg0: collections.abc.Set[int]) -> set[int]"
696+
)
697+
assert m.roundtrip_std_set_int({1, 2, 3}) == {1, 2, 3}
698+
assert m.roundtrip_std_set_int(SetLike(1, 2, 3)) == {1, 2, 3}
699+
assert m.roundtrip_std_set_int(set()) == set()
700+
assert m.roundtrip_std_set_int(SetLike()) == set()
701+
with pytest.raises(TypeError):
702+
m.roundtrip_std_set_int(FakeSetLike(1, 2, 3))

0 commit comments

Comments
 (0)