-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbs.h
53 lines (45 loc) · 1.22 KB
/
bs.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#ifndef DSL_BS_H
#define DSL_BS_H
// {{SMILE_PUBLIC_HEADER}}
#include <vector>
#include "dataset.h"
#include "bkgndknowledge.h"
class DSL_progress;
// Abstract callback interface accepted by DSL_bs::Learn (see its `eval`
// parameter). Implementations override Evaluate.
class DSL_bsEvaluator
{
public:
    // Evaluation hook.
    //   iteration   - iteration index supplied by the caller.
    //   bsScore     - score passed in for the current candidate
    //                 (NOTE(review): presumably the search's own score - confirm).
    //   bestScore   - best score known to the caller so far.
    //   net         - network being evaluated; non-const, so it may be modified.
    //   ds          - data set the evaluation is run against.
    //   matching    - data set / network matching entries.
    //   progress    - optional progress sink; may be a null pointer
    //                 (DSL_bs::Learn defaults its own progress argument to NULL).
    //   outputScore - out-parameter receiving the evaluator's score.
    // Returns an int; NOTE(review): presumably a DSL_* status code - confirm
    // against the implementation of DSL_bs::Learn.
    virtual int Evaluate(int iteration, double bsScore, double bestScore,
        DSL_network &net, const DSL_dataset &ds, const std::vector<DSL_datasetMatch> &matching,
        DSL_progress *progress,
        double &outputScore) = 0;

    // Polymorphic interfaces need a virtual destructor so that deleting a
    // derived evaluator through a DSL_bsEvaluator* is well defined.
    // Declared after Evaluate on purpose so the existing virtual-function
    // order is untouched. NOTE(review): confirm against any prebuilt binaries
    // that were compiled with the previous header.
    virtual ~DSL_bsEvaluator() {}
};
// Configuration holder and entry point for the "bs" structure-learning
// algorithm (NOTE(review): presumably Bayesian Search - confirm). All tuning
// knobs are public data members; the constructor only fills in defaults.
class DSL_bs
{
public:
    // Sets every numeric/boolean option to its default value. The initializer
    // list follows the declaration order of the members below; bkk is left to
    // its own default construction.
    DSL_bs()
        : maxParents(5),
          maxSearchTime(0),
          nrIteration(20),
          linkProbability(0.1),
          priorLinkProbability(0.001),
          priorSampleSize(50),
          seed(0),
          ThickThinning(false)
    {
    }

    virtual ~DSL_bs() {}

    // Runs the learning procedure on `ds_` and writes the result into `net`.
    // `progress`, `eval`, `bestScore` and `bestIteration` are optional and
    // default to NULL. Returns an int; NOTE(review): presumably a DSL_* status
    // code - confirm in the implementation file.
    virtual int Learn(
        const DSL_dataset &ds_,
        DSL_network &net,
        DSL_progress *progress = NULL,
        DSL_bsEvaluator *eval = NULL,
        double *bestScore = NULL,
        int *bestIteration = NULL) const;

    int maxParents;               // default 5
    int maxSearchTime;            // default 0 (NOTE(review): units / "0 = unlimited" not visible here)
    int nrIteration;              // default 20
    double linkProbability;       // default 0.1
    double priorLinkProbability;  // default 0.001
    int priorSampleSize;          // default 50
    int seed;                     // default 0
    bool ThickThinning;           // default false
    DSL_bkgndKnowledge bkk;       // background knowledge; default-constructed

protected:
    // Helpers for Learn; defined outside this header.
    int PreChecks(const DSL_dataset &ds_) const;
    void PrepareMask(const DSL_dataset &ds, std::vector< std::vector<char> > &mask) const;
};
#endif