Skip to content

Commit c323abc

Browse files
committed
This commit introduces multiple changes:
- Greedy hill-climbing and MMHC now accept a FactorType blacklist argument. - BayesianNetworkType.data_default_node_type() now returns a list of FactorType, indicating the priority of each FactorType for each data type. - BayesianNetworkBase.set_unknown_node_types() now accepts a FactorType blacklist argument. - Changed the HeterogeneousBN constructor and HeterogeneousBN.default_node_types() to accept lists of default FactorType. - Fixed an error in the use of the patience parameter. - Improved the validation of objects returned from Python extensions.
1 parent d91d65d commit c323abc

30 files changed

+852
-487
lines changed

docs/source/extending.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ of a :class:`SemiparametricBN <pybnesian.SemiparametricBN>` using an
435435
>>> df = generate_sample_data(300)
436436
>>> df_test = generate_sample_data(20, seed=1)
437437
>>> # Create a heterogeneous BN with "MyLG" factors as default.
438-
>>> het = HeterogeneousBN(MyLGType(), ["a", "b", "c", "d"], [("a", "c")])
438+
>>> het = HeterogeneousBN([MyLGType()], ["a", "b", "c", "d"], [("a", "c")])
439439
>>> het.set_node_type("a", CKDEType())
440440
>>> het.fit(df)
441441
>>> # Create a SemiparametricBN
@@ -474,9 +474,9 @@ different default factor types for different data types. For example, we can mix
474474
>>> df = generate_hybrid_sample_data(20)
475475
>>> # Create a heterogeneous BN with "MyLG" factors as default for continuous data and
476476
>>> # "DiscreteFactorType" for categorical data.
477-
>>> het = HeterogeneousBN({pa.float64(): MyLGType(),
478-
... pa.float32(): MyLGType(),
479-
... pa.dictionary(pa.int8(), pa.utf8()): DiscreteFactorType()},
477+
>>> het = HeterogeneousBN({pa.float64(): [MyLGType()],
478+
... pa.float32(): [MyLGType()],
479+
... pa.dictionary(pa.int8(), pa.utf8()): [DiscreteFactorType()]},
480480
... ["a", "b", "c", "d"],
481481
... [("a", "c")])
482482
>>> het.set_node_type("a", CKDEType())

pybnesian/factors/arguments.hpp

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
#include <pybind11/pybind11.h>
55
#include <factors/factors.hpp>
66
#include <util/hash_utils.hpp>
7+
#include <util/util_types.hpp>
8+
9+
using util::FactorTypeHash, util::FactorTypeEqualTo, util::PairNameType, util::NameFactorTypeHash,
10+
util::NameFactorTypeEqualTo;
711

812
namespace py = pybind11;
913

@@ -131,32 +135,6 @@ class Arguments {
131135
" an Args(...) (or tuple) or a Kwargs(...) (or dict).");
132136
}
133137

134-
struct FactorTypeHash {
135-
size_t operator()(const std::shared_ptr<FactorType>& ft) const { return ft->hash(); }
136-
};
137-
138-
struct FactorTypeEqualTo {
139-
bool operator()(const std::shared_ptr<FactorType>& lhs, const std::shared_ptr<FactorType>& rhs) const {
140-
return *lhs == *rhs;
141-
}
142-
};
143-
144-
using PairNameType = std::pair<std::string, std::shared_ptr<FactorType>>;
145-
146-
struct NameFactorTypeHash {
147-
size_t operator()(const PairNameType& p) const {
148-
size_t h = std::hash<std::string>{}(p.first);
149-
util::hash_combine(h, p.second->hash());
150-
return h;
151-
}
152-
};
153-
154-
struct NameFactorTypeEqualTo {
155-
bool operator()(const PairNameType& lhs, const PairNameType& rhs) const {
156-
return lhs.first == rhs.first && *lhs.second == *rhs.second;
157-
}
158-
};
159-
160138
std::unordered_map<std::string, std::pair<py::args, py::kwargs>> m_name_args;
161139
std::unordered_map<std::shared_ptr<FactorType>, std::pair<py::args, py::kwargs>, FactorTypeHash, FactorTypeEqualTo>
162140
m_type_args;

pybnesian/factors/factors.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,18 @@ class FactorType {
4747
return f;
4848
}
4949

50+
// Vector counterpart of FactorType::keep_python_alive: calls keep_python_alive
// on each element of `v`, in order, and returns the results in a new vector
// whose capacity is reserved up front to v.size().
// NOTE(review): presumably this keeps the Python-side object of a
// Python-derived FactorType alive alongside the returned shared_ptr; confirm
// against keep_python_alive, whose body is not fully visible here.
// NOTE(review): `v` is taken by non-const reference; whether keep_python_alive
// modifies its argument cannot be determined from this hunk.
static std::vector<std::shared_ptr<FactorType>> keep_vector_python_alive(
51+
std::vector<std::shared_ptr<FactorType>>& v) {
52+
std::vector<std::shared_ptr<FactorType>> fv;
53+
fv.reserve(v.size());
54+
55+
for (auto& f : v) {
56+
fv.push_back(FactorType::keep_python_alive(f));
57+
}
58+
59+
return fv;
60+
}
61+
5062
virtual std::shared_ptr<Factor> new_factor(const BayesianNetworkBase&,
5163
const std::string&,
5264
const std::vector<std::string>&,

pybnesian/learning/algorithms/dmmhc.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ std::shared_ptr<DynamicBayesianNetworkBase> DMMHC::estimate(const DynamicIndepen
6868
ArcStringVector arc_whitelist;
6969
EdgeStringVector edge_blacklist;
7070
EdgeStringVector edge_whitelist;
71+
FactorTypeVector type_blacklist;
7172
FactorTypeVector type_whitelist;
7273

7374
auto g0 = mmhc.estimate(static_tests,
@@ -79,6 +80,7 @@ std::shared_ptr<DynamicBayesianNetworkBase> DMMHC::estimate(const DynamicIndepen
7980
arc_whitelist,
8081
edge_blacklist,
8182
edge_whitelist,
83+
type_blacklist,
8284
type_whitelist,
8385
static_callback,
8486
max_indegree,
@@ -102,6 +104,7 @@ std::shared_ptr<DynamicBayesianNetworkBase> DMMHC::estimate(const DynamicIndepen
102104
arc_whitelist,
103105
edge_blacklist,
104106
edge_whitelist,
107+
type_blacklist,
105108
type_whitelist,
106109
transition_callback,
107110
max_indegree,
Lines changed: 2 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include <learning/algorithms/hillclimbing.hpp>
2-
#include <util/validate_whitelists.hpp>
32
#include <util/validate_options.hpp>
43
#include <dataset/dataset.hpp>
54
#include <models/BayesianNetwork.hpp>
@@ -31,6 +30,7 @@ std::shared_ptr<BayesianNetworkBase> hc(const DataFrame& df,
3130
const std::optional<std::vector<std::string>>& operators_str,
3231
const ArcStringVector& arc_blacklist,
3332
const ArcStringVector& arc_whitelist,
33+
const FactorTypeVector& type_blacklist,
3434
const FactorTypeVector& type_whitelist,
3535
const std::shared_ptr<Callback> callback,
3636
int max_indegree,
@@ -79,6 +79,7 @@ std::shared_ptr<BayesianNetworkBase> hc(const DataFrame& df,
7979
*start_model,
8080
arc_blacklist,
8181
arc_whitelist,
82+
type_blacklist,
8283
type_whitelist,
8384
callback,
8485
max_indegree,
@@ -88,95 +89,4 @@ std::shared_ptr<BayesianNetworkBase> hc(const DataFrame& df,
8889
verbose);
8990
}
9091

91-
template <typename T>
92-
std::shared_ptr<T> estimate_checks(OperatorSet& op_set,
93-
Score& score,
94-
const T& start,
95-
const ArcStringVector& arc_blacklist,
96-
const ArcStringVector& arc_whitelist,
97-
const FactorTypeVector& type_whitelist,
98-
const std::shared_ptr<Callback> callback,
99-
int max_indegree,
100-
int max_iters,
101-
double epsilon,
102-
int patience,
103-
int verbose) {
104-
if (!score.compatible_bn(start)) {
105-
throw std::invalid_argument("BayesianNetwork is not compatible with the score.");
106-
}
107-
108-
util::validate_restrictions(start, arc_blacklist, arc_whitelist);
109-
110-
if (auto validated_score = dynamic_cast<ValidatedScore*>(&score)) {
111-
return estimate_validation_hc(op_set,
112-
*validated_score,
113-
start,
114-
arc_blacklist,
115-
arc_whitelist,
116-
type_whitelist,
117-
callback,
118-
max_indegree,
119-
max_iters,
120-
epsilon,
121-
patience,
122-
verbose);
123-
} else {
124-
return estimate_hc(
125-
op_set, score, start, arc_blacklist, arc_whitelist, callback, max_indegree, max_iters, epsilon, verbose);
126-
}
127-
}
128-
129-
std::shared_ptr<BayesianNetworkBase> GreedyHillClimbing::estimate(OperatorSet& op_set,
130-
Score& score,
131-
const BayesianNetworkBase& start,
132-
const ArcStringVector& arc_blacklist,
133-
const ArcStringVector& arc_whitelist,
134-
const FactorTypeVector& type_whitelist,
135-
const std::shared_ptr<Callback> callback,
136-
int max_indegree,
137-
int max_iters,
138-
double epsilon,
139-
int patience,
140-
int verbose) {
141-
return estimate_checks(op_set,
142-
score,
143-
start,
144-
arc_blacklist,
145-
arc_whitelist,
146-
type_whitelist,
147-
callback,
148-
max_indegree,
149-
max_iters,
150-
epsilon,
151-
patience,
152-
verbose);
153-
}
154-
155-
std::shared_ptr<ConditionalBayesianNetworkBase> GreedyHillClimbing::estimate(
156-
OperatorSet& op_set,
157-
Score& score,
158-
const ConditionalBayesianNetworkBase& start,
159-
const ArcStringVector& arc_blacklist,
160-
const ArcStringVector& arc_whitelist,
161-
const FactorTypeVector& type_whitelist,
162-
const std::shared_ptr<Callback> callback,
163-
int max_indegree,
164-
int max_iters,
165-
double epsilon,
166-
int patience,
167-
int verbose) {
168-
return estimate_checks(op_set,
169-
score,
170-
start,
171-
arc_blacklist,
172-
arc_whitelist,
173-
type_whitelist,
174-
callback,
175-
max_indegree,
176-
max_iters,
177-
epsilon,
178-
patience,
179-
verbose);
180-
}
181-
18292
} // namespace learning::algorithms

0 commit comments

Comments
 (0)