Skip to content

Commit d070afc

Browse files
committed
Fixes the removal of some Python derived objects produced when a Python derived type is passed as argument of some functions.
This commit solves this problem by ensuring the Python side is kept alive. Related to: pybind/pybind11#1333
1 parent e473028 commit d070afc

12 files changed

+359
-66
lines changed

pybnesian/factors/factors.hpp

+55-4
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,18 @@ class FactorType {
3636

3737
virtual bool is_python_derived() const { return false; }
3838

39-
static std::shared_ptr<FactorType> keep_python_alive(std::shared_ptr<FactorType>& f) {
39+
static std::shared_ptr<FactorType>& keep_python_alive(std::shared_ptr<FactorType>& f) {
40+
if (f && f->is_python_derived()) {
41+
auto o = py::cast(f);
42+
auto keep_python_state_alive = std::make_shared<py::object>(o);
43+
auto ptr = o.cast<FactorType*>();
44+
f = std::shared_ptr<FactorType>(keep_python_state_alive, ptr);
45+
}
46+
47+
return f;
48+
}
49+
50+
static std::shared_ptr<FactorType> keep_python_alive(const std::shared_ptr<FactorType>& f) {
4051
if (f && f->is_python_derived()) {
4152
auto o = py::cast(f);
4253
auto keep_python_state_alive = std::make_shared<py::object>(o);
@@ -47,12 +58,21 @@ class FactorType {
4758
return f;
4859
}
4960

50-
static std::vector<std::shared_ptr<FactorType>> keep_vector_python_alive(
61+
static std::vector<std::shared_ptr<FactorType>>& keep_vector_python_alive(
5162
std::vector<std::shared_ptr<FactorType>>& v) {
63+
for (auto& f : v) {
64+
FactorType::keep_python_alive(f);
65+
}
66+
67+
return v;
68+
}
69+
70+
static std::vector<std::shared_ptr<FactorType>> keep_vector_python_alive(
71+
const std::vector<std::shared_ptr<FactorType>>& v) {
5272
std::vector<std::shared_ptr<FactorType>> fv;
5373
fv.reserve(v.size());
5474

55-
for (auto& f : v) {
75+
for (const auto& f : v) {
5676
fv.push_back(FactorType::keep_python_alive(f));
5777
}
5878

@@ -105,7 +125,18 @@ class Factor {
105125

106126
virtual bool is_python_derived() const { return false; }
107127

108-
static std::shared_ptr<Factor> keep_python_alive(std::shared_ptr<Factor>& f) {
128+
static std::shared_ptr<Factor>& keep_python_alive(std::shared_ptr<Factor>& f) {
129+
if (f && f->is_python_derived()) {
130+
auto o = py::cast(f);
131+
auto keep_python_state_alive = std::make_shared<py::object>(o);
132+
auto ptr = o.cast<Factor*>();
133+
f = std::shared_ptr<Factor>(keep_python_state_alive, ptr);
134+
}
135+
136+
return f;
137+
}
138+
139+
static std::shared_ptr<Factor> keep_python_alive(const std::shared_ptr<Factor>& f) {
109140
if (f && f->is_python_derived()) {
110141
auto o = py::cast(f);
111142
auto keep_python_state_alive = std::make_shared<py::object>(o);
@@ -116,6 +147,26 @@ class Factor {
116147
return f;
117148
}
118149

150+
static std::vector<std::shared_ptr<Factor>>& keep_vector_python_alive(std::vector<std::shared_ptr<Factor>>& v) {
151+
for (auto& f : v) {
152+
Factor::keep_python_alive(f);
153+
}
154+
155+
return v;
156+
}
157+
158+
static std::vector<std::shared_ptr<Factor>> keep_vector_python_alive(
159+
const std::vector<std::shared_ptr<Factor>>& v) {
160+
std::vector<std::shared_ptr<Factor>> fv;
161+
fv.reserve(v.size());
162+
163+
for (const auto& f : v) {
164+
fv.push_back(Factor::keep_python_alive(f));
165+
}
166+
167+
return fv;
168+
}
169+
119170
const std::string& variable() const { return m_variable; }
120171

121172
const std::vector<std::string>& evidence() const { return m_evidence; }

pybnesian/kde/BandwidthSelector.hpp

+12-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,18 @@ class BandwidthSelector {
1515

1616
virtual bool is_python_derived() const { return false; }
1717

18-
static std::shared_ptr<BandwidthSelector> keep_python_alive(std::shared_ptr<BandwidthSelector>& b) {
18+
static std::shared_ptr<BandwidthSelector>& keep_python_alive(std::shared_ptr<BandwidthSelector>& b) {
19+
if (b && b->is_python_derived()) {
20+
auto o = py::cast(b);
21+
auto keep_python_state_alive = std::make_shared<py::object>(o);
22+
auto ptr = o.cast<BandwidthSelector*>();
23+
b = std::shared_ptr<BandwidthSelector>(keep_python_state_alive, ptr);
24+
}
25+
26+
return b;
27+
}
28+
29+
static std::shared_ptr<BandwidthSelector> keep_python_alive(const std::shared_ptr<BandwidthSelector>& b) {
1930
if (b && b->is_python_derived()) {
2031
auto o = py::cast(b);
2132
auto keep_python_state_alive = std::make_shared<py::object>(o);

pybnesian/learning/operators/operators.hpp

+56-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,18 @@ class Operator {
3838
virtual bool operator==(const Operator& a) const = 0;
3939
bool operator!=(const Operator& a) const { return !(*this == a); }
4040

41-
static std::shared_ptr<Operator> keep_python_alive(std::shared_ptr<Operator>& op) {
41+
static std::shared_ptr<Operator>& keep_python_alive(std::shared_ptr<Operator>& op) {
42+
if (op && op->is_python_derived()) {
43+
auto o = py::cast(op);
44+
auto keep_python_state_alive = std::make_shared<py::object>(o);
45+
auto ptr = o.cast<Operator*>();
46+
op = std::shared_ptr<Operator>(keep_python_state_alive, ptr);
47+
}
48+
49+
return op;
50+
}
51+
52+
static std::shared_ptr<Operator> keep_python_alive(const std::shared_ptr<Operator>& op) {
4253
if (op && op->is_python_derived()) {
4354
auto o = py::cast(op);
4455
auto keep_python_state_alive = std::make_shared<py::object>(o);
@@ -330,6 +341,7 @@ class OperatorSet {
330341
public:
331342
OperatorSet() : m_local_cache(nullptr), m_owns_local_cache(false) {}
332343
virtual ~OperatorSet() {}
344+
virtual bool is_python_derived() const { return false; }
333345
virtual void cache_scores(const BayesianNetworkBase&, const Score&) = 0;
334346
virtual std::shared_ptr<Operator> find_max(const BayesianNetworkBase&) const = 0;
335347
virtual std::shared_ptr<Operator> find_max(const BayesianNetworkBase&, const OperatorTabuSet&) const = 0;
@@ -356,6 +368,49 @@ class OperatorSet {
356368
virtual void set_type_whitelist(const FactorTypeVector&){};
357369
virtual void finished() { m_local_cache = nullptr; }
358370

371+
static std::shared_ptr<OperatorSet>& keep_python_alive(std::shared_ptr<OperatorSet>& op_set) {
372+
if (op_set && op_set->is_python_derived()) {
373+
auto o = py::cast(op_set);
374+
auto keep_python_state_alive = std::make_shared<py::object>(o);
375+
auto ptr = o.cast<OperatorSet*>();
376+
op_set = std::shared_ptr<OperatorSet>(keep_python_state_alive, ptr);
377+
}
378+
379+
return op_set;
380+
}
381+
382+
static std::shared_ptr<OperatorSet> keep_python_alive(const std::shared_ptr<OperatorSet>& op_set) {
383+
if (op_set && op_set->is_python_derived()) {
384+
auto o = py::cast(op_set);
385+
auto keep_python_state_alive = std::make_shared<py::object>(o);
386+
auto ptr = o.cast<OperatorSet*>();
387+
return std::shared_ptr<OperatorSet>(keep_python_state_alive, ptr);
388+
}
389+
390+
return op_set;
391+
}
392+
393+
static std::vector<std::shared_ptr<OperatorSet>>& keep_vector_python_alive(
394+
std::vector<std::shared_ptr<OperatorSet>>& v) {
395+
for (auto& op_set : v) {
396+
OperatorSet::keep_python_alive(op_set);
397+
}
398+
399+
return v;
400+
}
401+
402+
static std::vector<std::shared_ptr<OperatorSet>> keep_vector_python_alive(
403+
const std::vector<std::shared_ptr<OperatorSet>>& v) {
404+
std::vector<std::shared_ptr<OperatorSet>> fv;
405+
fv.reserve(v.size());
406+
407+
for (const auto& op_set : v) {
408+
fv.push_back(OperatorSet::keep_python_alive(op_set));
409+
}
410+
411+
return fv;
412+
}
413+
359414
protected:
360415
bool owns_local_cache() const { return m_owns_local_cache; }
361416

pybnesian/models/BayesianNetwork.hpp

+37-3
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,18 @@ class BayesianNetworkBase : public std::enable_shared_from_this<BayesianNetworkB
118118
}
119119
}
120120

121-
static std::shared_ptr<BayesianNetworkBase> keep_python_alive(std::shared_ptr<BayesianNetworkBase>& m) {
121+
static std::shared_ptr<BayesianNetworkBase>& keep_python_alive(std::shared_ptr<BayesianNetworkBase>& m) {
122+
if (m && m->is_python_derived()) {
123+
auto o = py::cast(m);
124+
auto keep_python_state_alive = std::make_shared<py::object>(o);
125+
auto ptr = o.cast<BayesianNetworkBase*>();
126+
m = std::shared_ptr<BayesianNetworkBase>(keep_python_state_alive, ptr);
127+
}
128+
129+
return m;
130+
}
131+
132+
static std::shared_ptr<BayesianNetworkBase> keep_python_alive(const std::shared_ptr<BayesianNetworkBase>& m) {
122133
if (m && m->is_python_derived()) {
123134
auto o = py::cast(m);
124135
auto keep_python_state_alive = std::make_shared<py::object>(o);
@@ -182,8 +193,20 @@ class ConditionalBayesianNetworkBase : public BayesianNetworkBase {
182193
}
183194
}
184195

185-
static std::shared_ptr<ConditionalBayesianNetworkBase> keep_python_alive(
196+
static std::shared_ptr<ConditionalBayesianNetworkBase>& keep_python_alive(
186197
std::shared_ptr<ConditionalBayesianNetworkBase>& m) {
198+
if (m && m->is_python_derived()) {
199+
auto o = py::cast(m);
200+
auto keep_python_state_alive = std::make_shared<py::object>(o);
201+
auto ptr = o.cast<ConditionalBayesianNetworkBase*>();
202+
m = std::shared_ptr<ConditionalBayesianNetworkBase>(keep_python_state_alive, ptr);
203+
}
204+
205+
return m;
206+
}
207+
208+
static std::shared_ptr<ConditionalBayesianNetworkBase> keep_python_alive(
209+
const std::shared_ptr<ConditionalBayesianNetworkBase>& m) {
187210
if (m && m->is_python_derived()) {
188211
auto o = py::cast(m);
189212
auto keep_python_state_alive = std::make_shared<py::object>(o);
@@ -208,7 +231,18 @@ class BayesianNetworkType {
208231
virtual std::shared_ptr<ConditionalBayesianNetworkBase> new_cbn(
209232
const std::vector<std::string>& nodes, const std::vector<std::string>& interface_nodes) const = 0;
210233

211-
static std::shared_ptr<BayesianNetworkType> keep_python_alive(std::shared_ptr<BayesianNetworkType>& s) {
234+
static std::shared_ptr<BayesianNetworkType>& keep_python_alive(std::shared_ptr<BayesianNetworkType>& s) {
235+
if (s && s->is_python_derived()) {
236+
auto o = py::cast(s);
237+
auto keep_python_state_alive = std::make_shared<py::object>(o);
238+
auto ptr = o.cast<BayesianNetworkType*>();
239+
s = std::shared_ptr<BayesianNetworkType>(keep_python_state_alive, ptr);
240+
}
241+
242+
return s;
243+
}
244+
245+
static std::shared_ptr<BayesianNetworkType> keep_python_alive(const std::shared_ptr<BayesianNetworkType>& s) {
212246
if (s && s->is_python_derived()) {
213247
auto o = py::cast(s);
214248
auto keep_python_state_alive = std::make_shared<py::object>(o);

pybnesian/models/HeterogeneousBN.cpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22

33
namespace models {
44

5-
MapDataToFactor keep_MapDataToFactor_alive(MapDataToFactor& m) {
5+
MapDataToFactor& keep_MapDataToFactor_alive(MapDataToFactor& m) {
6+
for (auto& item : m) {
7+
FactorType::keep_vector_python_alive(item.second);
8+
}
9+
10+
return m;
11+
}
12+
13+
MapDataToFactor keep_MapDataToFactor_alive(const MapDataToFactor& m) {
614
MapDataToFactor alive;
715

816
for (auto& item : m) {

pybnesian/models/HeterogeneousBN.hpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,17 @@ class DataTypeEqualTo {
2222
using MapDataToFactor = std::
2323
unordered_map<std::shared_ptr<DataType>, std::vector<std::shared_ptr<FactorType>>, DataTypeHash, DataTypeEqualTo>;
2424

25-
MapDataToFactor keep_MapDataToFactor_alive(MapDataToFactor& m);
25+
MapDataToFactor& keep_MapDataToFactor_alive(MapDataToFactor& m);
26+
MapDataToFactor keep_MapDataToFactor_alive(const MapDataToFactor& m);
2627

2728
class HeterogeneousBNType : public BayesianNetworkType {
2829
public:
2930
HeterogeneousBNType(const HeterogeneousBNType&) = delete;
3031
void operator=(const HeterogeneousBNType&) = delete;
3132

33+
HeterogeneousBNType(HeterogeneousBNType&&) = default;
34+
HeterogeneousBNType& operator=(HeterogeneousBNType&&) = default;
35+
3236
HeterogeneousBNType(std::vector<std::shared_ptr<FactorType>> default_ft)
3337
: m_default_ftype(default_ft), m_default_ftypes(), m_single_default(true) {
3438
if (default_ft.empty()) throw std::invalid_argument("Default factor_type cannot be empty.");

pybnesian/models/HomogeneousBN.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ class HomogeneousBNType : public BayesianNetworkType {
1212
HomogeneousBNType(const HomogeneousBNType&) = delete;
1313
void operator=(const HomogeneousBNType&) = delete;
1414

15+
HomogeneousBNType(HomogeneousBNType&&) = default;
16+
HomogeneousBNType& operator=(HomogeneousBNType&&) = default;
17+
1518
HomogeneousBNType(std::shared_ptr<FactorType> ft) : m_ftype(ft) {
1619
if (ft == nullptr) throw std::invalid_argument("factor_type cannot be null.");
1720

pybnesian/pybindings/pybindings_factors.cpp

+19-7
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ class PyFactorType : public FactorType {
5252

5353
try {
5454
auto f = o.cast<std::shared_ptr<Factor>>();
55-
return Factor::keep_python_alive(f);
55+
Factor::keep_python_alive(f);
56+
return f;
5657
} catch (py::cast_error& e) {
5758
throw std::runtime_error("The returned object of FactorType::new_factor is not a Factor.");
5859
}
@@ -78,7 +79,8 @@ class PyFactorType : public FactorType {
7879

7980
try {
8081
auto f = o.cast<std::shared_ptr<Factor>>();
81-
return Factor::keep_python_alive(f);
82+
Factor::keep_python_alive(f);
83+
return f;
8284
} catch (py::cast_error& e) {
8385
throw std::runtime_error("The returned object of FactorType::new_factor is not a Factor.");
8486
}
@@ -160,7 +162,7 @@ class PyFactor : public Factor {
160162
try {
161163
m_type = o.cast<std::shared_ptr<FactorType>>();
162164
// Keep the type in the class member, so type_ref() can return a valid reference.
163-
m_type = FactorType::keep_python_alive(m_type);
165+
FactorType::keep_python_alive(m_type);
164166
return m_type;
165167
} catch (py::cast_error& e) {
166168
throw std::runtime_error("The returned object of Factor::type is not a FactorType.");
@@ -735,10 +737,20 @@ Removes the assignment for the ``variable``.
735737

736738
py::class_<HCKDE, Factor, std::shared_ptr<HCKDE>>(root, "HCKDE")
737739
.def(py::init<std::string, std::vector<std::string>>())
738-
.def(py::init<std::string, std::vector<std::string>, std::shared_ptr<BandwidthSelector>>())
739-
.def(py::init<std::string,
740-
std::vector<std::string>,
741-
std::unordered_map<Assignment, std::tuple<std::shared_ptr<BandwidthSelector>>, AssignmentHash>>())
740+
.def(py::init<>([](std::string variable,
741+
std::vector<std::string> evidence,
742+
std::shared_ptr<BandwidthSelector> bandwidth_selector) {
743+
return HCKDE(variable, evidence, BandwidthSelector::keep_python_alive(bandwidth_selector));
744+
}), py::arg("variable"), py::arg("evidence"), py::arg("bandwidth_selector"))
745+
.def(py::init<>([](std::string variable,
746+
std::vector<std::string> evidence,
747+
std::unordered_map<Assignment, std::tuple<std::shared_ptr<BandwidthSelector>>, AssignmentHash> args) {
748+
for (auto& arg : args) {
749+
BandwidthSelector::keep_python_alive(std::get<0>(arg.second));
750+
}
751+
752+
return HCKDE(variable, evidence, args);
753+
}), py::arg("variable"), py::arg("evidence"), py::arg("bandwidth_selector"))
742754
.def("conditional_factor", &HCKDE::conditional_factor, py::return_value_policy::reference_internal)
743755
.def(py::pickle([](const HCKDE& self) { return self.__getstate__(); },
744756
[](py::tuple t) { return HCKDE::__setstate__(t); }));

0 commit comments

Comments
 (0)