-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathDBN_tutorial.cpp
146 lines (123 loc) · 4.35 KB
/
DBN_tutorial.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#include <iostream>
#include <vector>

#include "smile.h"
using namespace std;
// The three tutorial steps, implemented below main().
void CreateDBN(void);
void InferenceDBN(void);
void UnrollDBN(void);

// Entry point: runs the tutorial steps in sequence. CreateDBN has to run
// first because the two later steps load the "dbn.xdsl" file it writes.
int main()
{
    typedef void (*TutorialStep)(void);
    const TutorialStep steps[] = { CreateDBN, InferenceDBN, UnrollDBN };
    const int stepCount = sizeof(steps) / sizeof(steps[0]);

    for (int i = 0; i < stepCount; ++i)
    {
        steps[i]();
    }
    return DSL_OKAY;
}
void CreateDBN(void)
{
// Initialize network and nodes.
DSL_network theDBN;
int rain = theDBN.AddNode(DSL_CPT ,"Rain");
int umbrella = theDBN.AddNode(DSL_CPT ,"Umbrella");
int area = theDBN.AddNode(DSL_CPT ,"Area");
// Create and add statenames to [rain], [umbrella].
DSL_idArray stateNames;
stateNames.Add("True");
stateNames.Add("False");
theDBN.GetNode(rain)->Definition()->SetNumberOfOutcomes(stateNames);
theDBN.GetNode(umbrella)->Definition()->SetNumberOfOutcomes(stateNames);
// Create and add statenames to [area].
stateNames.CleanUp();
stateNames.Add("Pittsburgh");
stateNames.Add("Sahara");
theDBN.GetNode(area)->Definition()->SetNumberOfOutcomes(stateNames);
// Add non-temporal arcs.
theDBN.AddArc(area, rain);
theDBN.AddArc(rain, umbrella);
// Add non-temporal probabilities to nodes.
DSL_doubleArray theProbs;
// Add probabilities to [area].
theProbs.SetSize(2);
theProbs[0] = 0.5;
theProbs[1] = 0.5;
theDBN.GetNode(area)->Definition()->SetDefinition(theProbs);
// Add probabilities to the initial CPT of [rain].
theProbs.SetSize(4);
theProbs[0] = 0.7;
theProbs[1] = 0.3;
theProbs[2] = 0.01;
theProbs[3] = 0.99;
theDBN.GetNode(rain)->Definition()->SetDefinition(theProbs);
// Add probabilities to [umbrella].
theProbs[0] = 0.9;
theProbs[1] = 0.1;
theProbs[2] = 0.2;
theProbs[3] = 0.8;
theDBN.GetNode(umbrella)->Definition()->SetDefinition(theProbs);
// Set temporal types.
theDBN.SetTemporalType(umbrella , dsl_plateNode);
theDBN.SetTemporalType(rain, dsl_plateNode);
// Add temporal arc.
theDBN.AddTemporalArc(rain, rain, 1);
// Add temporal probabilities to the first -order CPT of [rain].
theProbs.SetSize(8);
theProbs[0] = 0.7;
theProbs[1] = 0.3;
theProbs[2] = 0.3;
theProbs[3] = 0.7;
theProbs[4] = 0.001;
theProbs[5] = 0.999;
theProbs[6] = 0.01;
theProbs[7] = 0.99;
((DSL_cpt*)theDBN.GetNode(rain)->Definition())->SetTemporalProbabilities(1, theProbs);
// Write the DBN to a file.
theDBN.WriteFile("dbn.xdsl");
}
// Loads the DBN from "dbn.xdsl" (written by CreateDBN()) and sets it to 8 time
// slices (days). It enters Umbrella observations for slices 0-6 and runs
// inference. It then prints the belief of [area] and the predicted belief of
// [rain] for the final slice ("tomorrow", slice 7).
//
// Failures to load the file, locate the nodes, or update beliefs are
// reported on std::cerr and end the function early.
void InferenceDBN(void)
{
    DSL_network theDBN;
    // SMILE error codes are negative. Without this check a missing file would
    // make FindNode() return error codes that are then used as node handles.
    if (theDBN.ReadFile("dbn.xdsl") < 0)
    {
        std::cerr << "InferenceDBN: cannot read dbn.xdsl" << std::endl;
        return;
    }

    // Obtain the node handles.
    int rain = theDBN.FindNode("Rain");
    int umbrella = theDBN.FindNode("Umbrella");
    int area = theDBN.FindNode("Area");
    if (rain < 0 || umbrella < 0 || area < 0)
    {
        std::cerr << "InferenceDBN: dbn.xdsl lacks Rain, Umbrella or Area" << std::endl;
        return;
    }

    // Perform inference over a period of 8 days.
    theDBN.SetNumberOfSlices(8);

    // Set the evidence of the DBN: slice i receives umbrellaEvidence[i], an
    // outcome index of Umbrella (CreateDBN defines them as 0 = True,
    // 1 = False). The last slice stays unobserved; it is the day predicted.
    const int umbrellaEvidence[] = { 1, 1, 1, 0, 1, 1, 1 };
    const int observedDays = sizeof(umbrellaEvidence) / sizeof(umbrellaEvidence[0]);
    for (int slice = 0; slice < observedDays; ++slice)
    {
        theDBN.GetNode(umbrella)->Value()->SetTemporalEvidence(slice, umbrellaEvidence[slice]);
    }

    // Do inference.
    if (theDBN.UpdateBeliefs() < 0)
    {
        std::cerr << "InferenceDBN: belief update failed" << std::endl;
        return;
    }

    // Get beliefs.
    DSL_Dmatrix abeliefs;
    DSL_Dmatrix rbeliefs;
    abeliefs = *theDBN.GetNode(area)->Value()->GetMatrix();
    rbeliefs = *theDBN.GetNode(rain)->Value()->GetMatrix();

    // Print beliefs.
    DSL_idArray *stateNames;
    stateNames = theDBN.GetNode(area)->Definition()->GetOutcomesNames();
    const int areaOutcomes = theDBN.GetNode(area)->Definition()->GetNumberOfOutcomes();
    cout << "Beliefs of [area]" << endl;
    for (int i = 0; i < areaOutcomes; ++i)
    {
        cout << " " << (*stateNames)[i] << "\t" << abeliefs[i] << endl;
    }
    cout << endl;

    stateNames = theDBN.GetNode(rain)->Definition()->GetOutcomesNames();
    const int rainOutcomes = theDBN.GetNode(rain)->Definition()->GetNumberOfOutcomes();
    // [rain] is a plate node, so its value matrix stacks the beliefs of every
    // slice, one block of rainOutcomes entries per slice. "Tomorrow" is the
    // last block; entries 0..rainOutcomes-1 would be the first day instead.
    const int tomorrowOffset = rbeliefs.GetSize() - rainOutcomes;
    cout << "Beliefs of [rain] tomorrow" << endl;
    for (int i = 0; i < rainOutcomes; ++i)
    {
        cout << " " << (*stateNames)[i] << "\t\t" << rbeliefs[tomorrowOffset + i] << endl;
    }
    cout << endl;
}
// Loads the DBN from "dbn.xdsl" (written by CreateDBN()) and sets it to 8 time
// slices (days). It unrolls it into a separate network and saves that network
// to "dbn_unrolled_8.xdsl".
//
// Each failing step (load, unroll, save) is reported on std::cerr and ends
// the function early; SMILE error codes are negative.
void UnrollDBN(void)
{
    DSL_network theDBN;
    if (theDBN.ReadFile("dbn.xdsl") < 0)
    {
        std::cerr << "UnrollDBN: cannot read dbn.xdsl" << std::endl;
        return;
    }

    // Unroll DBN for a period of 8 days.
    theDBN.SetNumberOfSlices(8);

    // Save unrolled DBN to a file.
    DSL_network unrolled;
    // Output parameter of UnrollNetwork; presumably maps unrolled nodes back
    // to the original DBN (confirm in the SMILE docs). It is not used here.
    std::vector< int > dontcare;
    if (theDBN.UnrollNetwork(unrolled, dontcare) < 0)
    {
        std::cerr << "UnrollDBN: failed to unroll the network" << std::endl;
        return;
    }

    if (unrolled.WriteFile("dbn_unrolled_8.xdsl") < 0)
    {
        std::cerr << "UnrollDBN: failed to write dbn_unrolled_8.xdsl" << std::endl;
    }
}