@@ -52,7 +52,8 @@ class PyFactorType : public FactorType {
52
52
53
53
try {
54
54
auto f = o.cast <std::shared_ptr<Factor>>();
55
- return Factor::keep_python_alive (f);
55
+ Factor::keep_python_alive (f);
56
+ return f;
56
57
} catch (py::cast_error& e) {
57
58
throw std::runtime_error (" The returned object of FactorType::new_factor is not a Factor." );
58
59
}
@@ -78,7 +79,8 @@ class PyFactorType : public FactorType {
78
79
79
80
try {
80
81
auto f = o.cast <std::shared_ptr<Factor>>();
81
- return Factor::keep_python_alive (f);
82
+ Factor::keep_python_alive (f);
83
+ return f;
82
84
} catch (py::cast_error& e) {
83
85
throw std::runtime_error (" The returned object of FactorType::new_factor is not a Factor." );
84
86
}
@@ -160,7 +162,7 @@ class PyFactor : public Factor {
160
162
try {
161
163
m_type = o.cast <std::shared_ptr<FactorType>>();
162
164
// Keep the type in the class member, so type_ref() can return a valid reference.
163
- m_type = FactorType::keep_python_alive (m_type);
165
+ FactorType::keep_python_alive (m_type);
164
166
return m_type;
165
167
} catch (py::cast_error& e) {
166
168
throw std::runtime_error (" The returned object of Factor::type is not a FactorType." );
@@ -735,10 +737,20 @@ Removes the assignment for the ``variable``.
735
737
736
738
py::class_<HCKDE, Factor, std::shared_ptr<HCKDE>>(root, " HCKDE" )
737
739
.def (py::init<std::string, std::vector<std::string>>())
738
- .def (py::init<std::string, std::vector<std::string>, std::shared_ptr<BandwidthSelector>>())
739
- .def (py::init<std::string,
740
- std::vector<std::string>,
741
- std::unordered_map<Assignment, std::tuple<std::shared_ptr<BandwidthSelector>>, AssignmentHash>>())
740
+ .def (py::init<>([](std::string variable,
741
+ std::vector<std::string> evidence,
742
+ std::shared_ptr<BandwidthSelector> bandwidth_selector) {
743
+ return HCKDE (variable, evidence, BandwidthSelector::keep_python_alive (bandwidth_selector));
744
+ }), py::arg (" variable" ), py::arg (" evidence" ), py::arg (" bandwidth_selector" ))
745
+ .def (py::init<>([](std::string variable,
746
+ std::vector<std::string> evidence,
747
+ std::unordered_map<Assignment, std::tuple<std::shared_ptr<BandwidthSelector>>, AssignmentHash> args) {
748
+ for (auto & arg : args) {
749
+ BandwidthSelector::keep_python_alive (std::get<0 >(arg.second ));
750
+ }
751
+
752
+ return HCKDE (variable, evidence, args);
753
+ }), py::arg (" variable" ), py::arg (" evidence" ), py::arg (" bandwidth_selector" ))
742
754
.def (" conditional_factor" , &HCKDE::conditional_factor, py::return_value_policy::reference_internal)
743
755
.def (py::pickle ([](const HCKDE& self) { return self.__getstate__ (); },
744
756
[](py::tuple t) { return HCKDE::__setstate__ (t); }));
0 commit comments