Skip to content

Commit 5a11a72

Browse files
committed
Only keep python portion around if there's an alias class
1 parent 2b93ea7 commit 5a11a72

File tree

5 files changed

+56
-1
lines changed

5 files changed

+56
-1
lines changed

include/pybind11/cast.h

+8
Original file line numberDiff line numberDiff line change
@@ -1547,6 +1547,14 @@ struct holder_retriever<std::shared_ptr<T>> {
15471547
};
15481548

15491549
static auto get_derivative_holder(const value_and_holder &v_h) -> std::shared_ptr<T> {
1550+
// If there's no trampoline class, nothing special needed
1551+
if (!v_h.inst->has_alias) {
1552+
return v_h.template holder<std::shared_ptr<T>>();
1553+
}
1554+
1555+
// If there's a trampoline class, ensure the python side of the object doesn't
1556+
// die until the C++ portion also dies
1557+
//
15501558
// The shared_ptr is always given to C++ code, so construct a new shared_ptr
15511559
// that is given a custom deleter. The custom deleter increments the python
15521560
// reference count to bind the python instance lifetime with the lifetime

include/pybind11/detail/common.h

+2
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,8 @@ struct instance {
475475
bool simple_instance_registered : 1;
476476
/// If true, get_internals().patients has an entry for this object
477477
bool has_patients : 1;
478+
/// If true, created with an associated alias class (set via `init_instance`)
479+
bool has_alias : 1;
478480

479481
/// Initializes all of the above type/values/holders data (but not the instance values themselves)
480482
void allocate_layout();

include/pybind11/pybind11.h

+11-1
Original file line numberDiff line numberDiff line change
@@ -1282,7 +1282,11 @@ class class_ : public detail::generic_type {
12821282
record.type_size = sizeof(conditional_t<has_alias, type_alias, type>);
12831283
record.type_align = alignof(conditional_t<has_alias, type_alias, type>&);
12841284
record.holder_size = sizeof(holder_type);
1285-
record.init_instance = init_instance;
1285+
if (has_alias) {
1286+
record.init_instance = init_alias_instance;
1287+
} else {
1288+
record.init_instance = init_instance;
1289+
}
12861290
record.dealloc = dealloc;
12871291
record.default_holder = detail::is_instantiation<std::unique_ptr, holder_type>::value;
12881292

@@ -1555,6 +1559,12 @@ class class_ : public detail::generic_type {
15551559
init_holder(inst, v_h, (const holder_type *) holder_ptr, v_h.value_ptr<type>());
15561560
}
15571561

1562+
/// Sets the `has_alias` flag in the instance and calls init_instance
1563+
static void init_alias_instance(detail::instance *inst, const void *holder_ptr) {
1564+
inst->has_alias = true;
1565+
init_instance(inst, holder_ptr);
1566+
}
1567+
15581568
/// Deallocates an instance; via holder, if constructed; otherwise via operator delete.
15591569
static void dealloc(detail::value_and_holder &v_h) {
15601570
// We could be deallocating because we are cleaning up after a Python exception.

tests/test_smart_ptr.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -430,4 +430,19 @@ TEST_SUBMODULE(smart_ptr, m) {
430430
.def("set_object", &SpBaseTester::set_object)
431431
.def("is_base_used", &SpBaseTester::is_base_used)
432432
.def_readwrite("obj", &SpBaseTester::m_obj);
433+
434+
// For testing that a C++ class without an alias does not retain the python
435+
// portion of the object
436+
struct SpGoAway {};
437+
438+
struct SpGoAwayTester {
439+
std::shared_ptr<SpGoAway> m_obj;
440+
};
441+
442+
py::class_<SpGoAway, std::shared_ptr<SpGoAway>>(m, "SpGoAway")
443+
.def(py::init<>());
444+
445+
py::class_<SpGoAwayTester>(m, "SpGoAwayTester")
446+
.def(py::init<>())
447+
.def_readwrite("obj", &SpGoAwayTester::m_obj);
433448
}

tests/test_smart_ptr.py

+20
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,23 @@ def test_shared_ptr_arg_identity():
378378
tester.set_object(None)
379379
pytest.gc_collect()
380380
assert objref() is None
381+
382+
383+
def test_shared_ptr_goaway():
384+
import weakref
385+
386+
tester = m.SpGoAwayTester()
387+
388+
obj = m.SpGoAway()
389+
objref = weakref.ref(obj)
390+
391+
assert tester.obj is None
392+
393+
tester.obj = obj
394+
del obj
395+
pytest.gc_collect()
396+
397+
# python reference is no longer around
398+
assert objref() is None
399+
# C++ reference is still around
400+
assert tester.obj is not None

0 commit comments

Comments
 (0)