1
import numpy as np

from notes_utilities import randgen, log_sum_exp, normalize_exp, normalize
3
+
4
class HMM(object):
    """Discrete-observation Hidden Markov Model.

    Parameterization (columns are the conditioning variable):
      pi : (S,)   initial state distribution, pi[i] = p(x_0 = i)
      A  : (S, S) transition matrix, A[i, j] = p(x_k = i | x_{k-1} = j)
      B  : (R, S) emission matrix,   B[r, s] = p(y_k = r | x_k = s)

    All recursions run in the log domain for numerical stability; the log
    parameters are cached on construction and kept in sync by set_param.
    """

    def __init__(self, pi, A, B):
        # p(x_0)
        self.pi = pi
        # p(x_k|x_{k-1})
        self.A = A
        # p(y_k|x_{k})
        self.B = B
        # Number of possible latent states at each time
        self.S = pi.shape[0]
        # Number of possible observations at each time
        self.R = B.shape[0]
        # Cached log parameters used by the log-domain recursions below.
        self.logB = np.log(self.B)
        self.logA = np.log(self.A)
        self.logpi = np.log(self.pi)

    def set_param(self, pi=None, A=None, B=None):
        """Replace any subset of (pi, A, B), keeping the log caches in sync."""
        if pi is not None:
            self.pi = pi
            self.logpi = np.log(self.pi)

        if A is not None:
            self.A = A
            self.logA = np.log(self.A)

        if B is not None:
            self.B = B
            self.logB = np.log(self.B)

    @classmethod
    def from_random_parameters(cls, S=3, R=5):
        """Alternate constructor: draw each column of A, B and pi from a
        symmetric Dirichlet(0.7) so all columns are valid distributions."""
        A = np.random.dirichlet(0.7 * np.ones(S), S).T
        B = np.random.dirichlet(0.7 * np.ones(R), S).T
        pi = np.random.dirichlet(0.7 * np.ones(S)).T
        return cls(pi, A, B)

    def __str__(self):
        s = "Prior:\n " + str(self.pi) + "\n A:\n " + str(self.A) + "\n B:\n " + str(self.B)
        return s

    def __repr__(self):
        s = self.__str__()
        return s

    def predict(self, lp):
        """One-step prediction in the log domain:
        given lp = log p(x_{k-1} | y_{1:k-1}), return log p(x_k | y_{1:k-1}).
        """
        # Shift by the max before exponentiating to avoid underflow
        # (log-sum-exp trick applied to the matrix-vector product).
        lstar = np.max(lp)
        return lstar + np.log(np.dot(self.A, np.exp(lp - lstar)))

    def postdict(self, lp):
        """Backward analogue of predict: propagate a log message through A
        from the left (i.e. multiply by A^T in the probability domain)."""
        lstar = np.max(lp)
        return lstar + np.log(np.dot(np.exp(lp - lstar), self.A))

    def predict_maxm(self, lp):
        """Max-marginal (max-sum) analogue of predict."""
        return np.max(self.logA + lp, axis=1)

    def postdict_maxm(self, lp):
        """Max-marginal (max-sum) analogue of postdict."""
        return np.max(self.logA.T + lp, axis=1)

    def update(self, y, lp):
        """Fold observation y into the log message lp.
        A NaN observation is treated as missing and leaves lp unchanged."""
        return self.logB[y, :] + lp if not np.isnan(y) else lp

    def generate_sequence(self, T=10):
        """Sample a latent trajectory x and observations y.

        T: Number of steps. Returns (y, x) as int arrays of length T.
        """
        x = np.zeros(T, int)
        y = np.zeros(T, int)

        for t in range(T):
            if t == 0:
                x[t] = randgen(self.pi)
            else:
                x[t] = randgen(self.A[:, x[t - 1]])
            y[t] = randgen(self.B[:, x[t]])

        return y, x

    def forward(self, y, maxm=False):
        """Forward pass over observations y.

        With maxm=False runs the sum-product filter; with maxm=True the
        max-sum (max-marginal) recursion.

        Python indices start from zero, so for k = 0..T-1:
          log alpha_{k|k}   is stored in log_alpha[:, k]
          log alpha_{k|k-1} is stored in log_alpha_pred[:, k]
        """
        T = len(y)

        log_alpha = np.zeros((self.S, T))
        log_alpha_pred = np.zeros((self.S, T))
        for k in range(T):
            if k == 0:
                # The first prediction is just the prior.
                log_alpha_pred[:, 0] = self.logpi
            else:
                if maxm:
                    log_alpha_pred[:, k] = self.predict_maxm(log_alpha[:, k - 1])
                else:
                    log_alpha_pred[:, k] = self.predict(log_alpha[:, k - 1])

            log_alpha[:, k] = self.update(y[k], log_alpha_pred[:, k])

        return log_alpha, log_alpha_pred

    def backward(self, y, maxm=False):
        """Backward pass over observations y (sum-product, or max-sum when
        maxm=True). Returns (log_beta, log_beta_post)."""
        T = len(y)
        log_beta = np.zeros((self.S, T))
        log_beta_post = np.zeros((self.S, T))

        for k in range(T - 1, -1, -1):
            if k == T - 1:
                # The terminal backward message is uniform (log 1 = 0).
                log_beta_post[:, k] = np.zeros(self.S)
            else:
                if maxm:
                    log_beta_post[:, k] = self.postdict_maxm(log_beta[:, k + 1])
                else:
                    log_beta_post[:, k] = self.postdict(log_beta[:, k + 1])

            log_beta[:, k] = self.update(y[k], log_beta_post[:, k])

        return log_beta, log_beta_post

    def forward_backward_smoother(self, y):
        """Forward-backward smoother: returns the unnormalized smoothed
        log-marginals log gamma_k(x) = log alpha_{k|k} + log beta_post."""
        log_alpha, log_alpha_pred = self.forward(y)
        log_beta, log_beta_post = self.backward(y)

        log_gamma = log_alpha + log_beta_post
        return log_gamma

    def viterbi(self, y):
        """Decode a state trajectory by filtering then backtracking.

        NOTE(review): the forward pass below uses the sum-product filter
        (self.predict), not max-product, so the backward argmax is greedy
        w.r.t. the filtered marginals rather than exact Viterbi decoding —
        confirm this is intended (see viterbi_maxsum for the max-sum variant).
        """
        T = len(y)

        # Forward (filtering) pass.
        log_alpha = np.zeros((self.S, T))
        for k in range(T):
            if k == 0:
                log_alpha_pred = self.logpi
            else:
                log_alpha_pred = self.predict(log_alpha[:, k - 1])

            log_alpha[:, k] = self.update(y[k], log_alpha_pred)

        # Backtrack: pick the best terminal state, then at each earlier step
        # the state maximizing alpha_k(x) * A[w_{k+1}, x].
        xs = list()
        w = np.argmax(log_alpha[:, -1])
        xs.insert(0, w)
        for k in range(T - 2, -1, -1):
            w = np.argmax(log_alpha[:, k] + self.logA[w, :])
            xs.insert(0, w)

        return xs

    def viterbi_maxsum(self, y):
        """Vanilla implementation of Viterbi decoding via max-sum.

        This algorithm may fail to find the MAP trajectory as it breaks
        ties arbitrarily (argmax per time step on the max-marginals).
        """
        log_alpha, log_alpha_pred = self.forward(y, maxm=True)
        log_beta, log_beta_post = self.backward(y, maxm=True)

        log_delta = log_alpha + log_beta_post
        return np.argmax(log_delta, axis=0)

    def correction_smoother(self, y):
        """Correction (Rauch-style) smoother.

        Returns the smoothed log-marginals log_gamma_corr and the expected
        sufficient statistics:
          C1 : (S,)   smoothed initial-state distribution
          C2 : (S, S) accumulated pairwise transition posteriors
          C3 : (R, S) accumulated emission posteriors indexed by observation
        """
        log_alpha, log_alpha_pred = self.forward(y)
        T = len(y)

        # For numerical stability, we calculate everything in the log domain.
        log_gamma_corr = np.zeros_like(log_alpha)
        log_gamma_corr[:, T - 1] = log_alpha[:, T - 1]

        C2 = np.zeros((self.S, self.S))
        C3 = np.zeros((self.R, self.S))
        C3[y[-1], :] = normalize_exp(log_alpha[:, T - 1])
        for k in range(T - 2, -1, -1):
            # p(x_k, x_{k+1} | y_{1:k}) up to normalization.
            log_old_pairwise_marginal = log_alpha[:, k].reshape(1, self.S) + self.logA
            log_old_marginal = self.predict(log_alpha[:, k])
            # Correct the pairwise marginal with the already-smoothed
            # marginal at k+1 (divide out the predictive, multiply in gamma).
            log_new_pairwise_marginal = (log_old_pairwise_marginal
                                         + log_gamma_corr[:, k + 1].reshape(self.S, 1)
                                         - log_old_marginal.reshape(self.S, 1))
            log_gamma_corr[:, k] = log_sum_exp(log_new_pairwise_marginal, axis=0).reshape(self.S)
            C2 += normalize_exp(log_new_pairwise_marginal)
            C3[y[k], :] += normalize_exp(log_gamma_corr[:, k])
        C1 = normalize_exp(log_gamma_corr[:, 0])
        return log_gamma_corr, C1, C2, C3

    def forward_only_SS(self, y, V=None):
        """Forward-only estimation of expected sufficient statistics.

        V, when given, is a tuple (V1, V2, V3) of recursion state allowing
        continuation across calls. Returns (C1, C2, C3, ll, (V1, V2, V3))
        where ll is the log-likelihood of y.
        """
        T = len(y)

        if V is None:
            V1 = np.eye(self.S)
            V2 = np.zeros((self.S, self.S, self.S))
            V3 = np.zeros((self.R, self.S, self.S))
        else:
            V1, V2, V3 = V

        I_S1S = np.eye(self.S).reshape((self.S, 1, self.S))
        I_RR = np.eye(self.R)

        for k in range(T):
            if k == 0:
                log_alpha_pred = self.logpi
            else:
                log_alpha_pred = self.predict(log_alpha)

            if k > 0:
                # Calculate p(x_{k-1}|y_{1:k-1}, x_k)
                lp = np.log(normalize_exp(log_alpha)).reshape(self.S, 1) + self.logA.T
                P = normalize_exp(lp, axis=0)

                # Propagate the sufficient-statistic recursions through P.
                V1 = np.dot(V1, P)
                V2 = np.dot(V2, P) + I_S1S * P.reshape((1, self.S, self.S))
                V3 = np.dot(V3, P) + I_RR[:, y[k - 1]].reshape((self.R, 1, 1)) * P.reshape((1, self.S, self.S))

            log_alpha = self.update(y[k], log_alpha_pred)
        p_xT = normalize_exp(log_alpha)

        C1 = np.dot(V1, p_xT.reshape(self.S, 1))
        C2 = np.dot(V2, p_xT.reshape(1, self.S, 1)).reshape((self.S, self.S))
        C3 = np.dot(V3, p_xT.reshape(1, self.S, 1)).reshape((self.R, self.S))
        C3[y[-1], :] += p_xT

        ll = log_sum_exp(log_alpha)

        return C1, C2, C3, ll, (V1, V2, V3)

    def train_EM(self, y, EPOCH=10):
        """Fit (pi, A, B) to observations y by EM, using the forward-only
        sufficient statistics. Returns the per-epoch log-likelihoods."""
        LL = np.zeros(EPOCH)
        for e in range(EPOCH):
            # E-step: expected sufficient statistics and current likelihood.
            C1, C2, C3, ll, V = self.forward_only_SS(y)
            LL[e] = ll
            # M-step: renormalize the statistics into parameters.
            # The +0.1 acts as a small pseudo-count smoothing the prior.
            p = normalize(C1 + 0.1, axis=0).reshape(self.S)
            A = normalize(C2, axis=0)
            B = normalize(C3, axis=0)
            # Re-initialize so S, R and the log caches are all refreshed.
            self.__init__(p, A, B)

        return LL
245
+
0 commit comments