Skip to content

Commit caa2277

Browse files
committed
Added support for Mapping, Set and Sequence derived from collections.abc.
1 parent 80a9086 commit caa2277

File tree

3 files changed

+145
-2
lines changed

3 files changed

+145
-2
lines changed

include/pybind11/stl.h

+16-2
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,14 @@ struct set_caster {
173173
public:
174174
bool load(handle src, bool convert) {
175175
if (!PyObjectTypeIsConvertibleToStdSet(src.ptr())) {
176-
return false;
176+
if (!convert) {
177+
return false;
178+
}
179+
if (!(isinstance(src, module_::import("collections.abc").attr("Set"))
180+
&& hasattr(src, "__contains__") && hasattr(src, "__iter__")
181+
&& hasattr(src, "__len__"))) {
182+
return false;
183+
}
177184
}
178185
if (isinstance<anyset>(src)) {
179186
value.clear();
@@ -237,7 +244,14 @@ struct map_caster {
237244
public:
238245
bool load(handle src, bool convert) {
239246
if (!PyObjectTypeIsConvertibleToStdMap(src.ptr())) {
240-
return false;
247+
if (!convert) {
248+
return false;
249+
}
250+
if (!(isinstance(src, module_::import("collections.abc").attr("Mapping"))
251+
&& hasattr(src, "__getitem__") && hasattr(src, "__iter__")
252+
&& hasattr(src, "__len__"))) {
253+
return false;
254+
}
241255
}
242256
if (isinstance<dict>(src)) {
243257
return convert_elements(reinterpret_borrow<dict>(src), convert);

tests/test_stl.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -648,4 +648,7 @@ TEST_SUBMODULE(stl, m) {
648648
}
649649
return zum;
650650
});
651+
m.def("roundtrip_std_vector_int", [](const std::vector<int> &v) { return v; });
652+
m.def("roundtrip_std_map_str_int", [](const std::map<std::string, int> &m) { return m; });
653+
m.def("roundtrip_std_set_int", [](const std::set<int> &s) { return s; });
651654
}

tests/test_stl.py

+126
Original file line numberDiff line numberDiff line change
@@ -576,3 +576,129 @@ def gen_invalid():
576576
with pytest.raises(expected_exception):
577577
m.pass_std_map_int(FakePyMappingGenObj(gen_obj))
578578
assert not tuple(gen_obj)
579+
580+
581+
def test_sequence_caster_protocol(doc):
582+
from collections.abc import Sequence
583+
584+
class SequenceLike(Sequence):
585+
def __init__(self, *args):
586+
self.data = tuple(args)
587+
588+
def __len__(self):
589+
return len(self.data)
590+
591+
def __getitem__(self, index):
592+
return self.data[index]
593+
594+
class FakeSequenceLike:
595+
def __init__(self, *args):
596+
self.data = tuple(args)
597+
598+
def __len__(self):
599+
return len(self.data)
600+
601+
def __getitem__(self, index):
602+
return self.data[index]
603+
604+
assert (
605+
doc(m.roundtrip_std_vector_int)
606+
== "roundtrip_std_vector_int(arg0: collections.abc.Sequence[int]) -> list[int]"
607+
)
608+
assert m.roundtrip_std_vector_int([1, 2, 3]) == [1, 2, 3]
609+
assert m.roundtrip_std_vector_int((1, 2, 3)) == [1, 2, 3]
610+
assert m.roundtrip_std_vector_int(SequenceLike(1, 2, 3)) == [1, 2, 3]
611+
assert m.roundtrip_std_vector_int(FakeSequenceLike(1, 2, 3)) == [1, 2, 3]
612+
assert m.roundtrip_std_vector_int([]) == []
613+
assert m.roundtrip_std_vector_int(()) == []
614+
assert m.roundtrip_std_vector_int(FakeSequenceLike()) == []
615+
616+
617+
def test_mapping_caster_protocol(doc):
618+
from collections.abc import Mapping
619+
620+
class MappingLike(Mapping):
621+
def __init__(self, **kwargs):
622+
self.data = dict(kwargs)
623+
624+
def __len__(self):
625+
return len(self.data)
626+
627+
def __getitem__(self, key):
628+
return self.data[key]
629+
630+
def __iter__(self):
631+
yield from self.data
632+
633+
class FakeMappingLike:
634+
def __init__(self, **kwargs):
635+
self.data = dict(kwargs)
636+
637+
def __len__(self):
638+
return len(self.data)
639+
640+
def __getitem__(self, key):
641+
return self.data[key]
642+
643+
def __iter__(self):
644+
yield from self.data
645+
646+
assert (
647+
doc(m.roundtrip_std_map_str_int)
648+
== "roundtrip_std_map_str_int(arg0: collections.abc.Mapping[str, int]) -> dict[str, int]"
649+
)
650+
assert m.roundtrip_std_map_str_int({"a": 1, "b": 2, "c": 3}) == {
651+
"a": 1,
652+
"b": 2,
653+
"c": 3,
654+
}
655+
assert m.roundtrip_std_map_str_int(MappingLike(a=1, b=2, c=3)) == {
656+
"a": 1,
657+
"b": 2,
658+
"c": 3,
659+
}
660+
assert m.roundtrip_std_map_str_int({}) == {}
661+
assert m.roundtrip_std_map_str_int(MappingLike()) == {}
662+
with pytest.raises(TypeError):
663+
m.roundtrip_std_map_str_int(FakeMappingLike(a=1, b=2, c=3))
664+
665+
666+
def test_set_caster_protocol(doc):
667+
from collections.abc import Set
668+
669+
class SetLike(Set):
670+
def __init__(self, *args):
671+
self.data = set(args)
672+
673+
def __len__(self):
674+
return len(self.data)
675+
676+
def __contains__(self, item):
677+
return item in self.data
678+
679+
def __iter__(self):
680+
yield from self.data
681+
682+
class FakeSetLike:
683+
def __init__(self, *args):
684+
self.data = set(args)
685+
686+
def __len__(self):
687+
return len(self.data)
688+
689+
def __contains__(self, item):
690+
return item in self.data
691+
692+
def __iter__(self):
693+
yield from self.data
694+
695+
assert (
696+
doc(m.roundtrip_std_set_int)
697+
== "roundtrip_std_set_int(arg0: collections.abc.Set[int]) -> set[int]"
698+
)
699+
assert m.roundtrip_std_set_int({1, 2, 3}) == {1, 2, 3}
700+
assert m.roundtrip_std_set_int(SetLike(1, 2, 3)) == {1, 2, 3}
701+
assert m.roundtrip_std_set_int(set()) == set()
702+
assert m.roundtrip_std_set_int(SetLike()) == set()
703+
with pytest.raises(TypeError):
704+
m.roundtrip_std_set_int(FakeSetLike(1, 2, 3))

0 commit comments

Comments
 (0)