Skip to content

Commit 9a8ddcd

Browse files
author
Jouni Helske
committed
predict method with bootstrap
1 parent f751987 commit 9a8ddcd

File tree

5 files changed

+203
-36
lines changed

5 files changed

+203
-36
lines changed

R/RcppExports.R

+8
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,14 @@ predict_fanhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, o
173173
.Call(`_seqHMM_predict_fanhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, W_A, W_B)
174174
}
175175

176+
boot_predict_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, gamma_pi, gamma_A, gamma_B) {
177+
.Call(`_seqHMM_boot_predict_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, gamma_pi, gamma_A, gamma_B)
178+
}
179+
180+
boot_predict_fanhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, W_A, W_B, gamma_pi, gamma_A, gamma_B) {
181+
.Call(`_seqHMM_boot_predict_fanhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, W_A, W_B, gamma_pi, gamma_A, gamma_B)
182+
}
183+
176184
simulate_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B) {
177185
.Call(`_seqHMM_simulate_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B)
178186
}

R/predict.R

+69-6
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,43 @@ predict.nhmm <- function(object, newdata = NULL, ...) {
5151
rownames(object$observations),
5252
each = object$n_symbols * object$length_of_sequences
5353
),
54-
estimate = c(out$obs_prob)
55-
)
54+
estimate = c(out)
55+
) |>
56+
stats::na.omit()
57+
colnames(d)[2] <- time
58+
if (!is.null(object$boot)) {
59+
out <- boot_predict_nhmm_singlechannel(
60+
object$etas$pi, object$X_pi,
61+
object$etas$A, object$X_A,
62+
object$etas$B, object$X_B,
63+
array(obs[1, , ], dim(obs)[2:3]),
64+
object$sequence_lengths,
65+
attr(object$X_pi, "icpt_only"), attr(object$X_A, "icpt_only"),
66+
attr(object$X_B, "icpt_only"), attr(object$X_A, "iv"),
67+
attr(object$X_B, "iv"), attr(object$X_A, "tv"), attr(object$X_B, "tv"),
68+
object$boot$gamma_pi, object$boot$gamma_A, object$boot$gamma_B
69+
)
70+
d_boot <- data.frame(
71+
observation = object$symbol_names,
72+
time = rep(
73+
as.numeric(colnames(object$observations)),
74+
each = object$n_symbols
75+
),
76+
id = rep(
77+
rownames(object$observations),
78+
each = object$n_symbols * object$length_of_sequences
79+
),
80+
estimate = unlist(out$obs_prob)
81+
) |>
82+
stats::na.omit()
83+
colnames(d_boot)[2] <- time
84+
d$type <- "MLE"
85+
d_boot$type <- "Bootstrap"
86+
d <- rbind(d, d_boot)
87+
}
5688
} else {
5789
stop("Not yet implemented")
5890
}
59-
colnames(d)[2] <- time
6091
d
6192
}
6293

@@ -124,9 +155,41 @@ predict.fanhmm <- function(object, newdata = NULL, ...) {
124155
rownames(object$observations),
125156
each = object$n_symbols * object$length_of_sequences
126157
),
127-
estimate = c(out$obs_prob)
128-
)
129-
158+
estimate = c(out)
159+
) |>
160+
stats::na.omit()
130161
colnames(d)[2] <- time
162+
163+
if (!is.null(object$boot)) {
164+
out <- boot_predict_fanhmm_singlechannel(
165+
object$etas$pi, object$X_pi,
166+
object$etas$A, object$X_A,
167+
object$etas$B, object$X_B,
168+
array(obs[1, , ], dim(obs)[2:3]),
169+
object$sequence_lengths,
170+
attr(object$X_pi, "icpt_only"), attr(object$X_A, "icpt_only"),
171+
attr(object$X_B, "icpt_only"), attr(object$X_A, "iv"),
172+
attr(object$X_B, "iv"), attr(object$X_A, "tv"), attr(object$X_B, "tv"),
173+
W_A, W_B,
174+
object$boot$gamma_pi, object$boot$gamma_A, object$boot$gamma_B
175+
)
176+
d_boot <- data.frame(
177+
observation = object$symbol_names,
178+
time = rep(
179+
as.numeric(colnames(object$observations)),
180+
each = object$n_symbols
181+
),
182+
id = rep(
183+
rownames(object$observations),
184+
each = object$n_symbols * object$length_of_sequences
185+
),
186+
estimate = unlist(out$obs_prob)
187+
) |>
188+
stats::na.omit()
189+
colnames(d_boot)[2] <- time
190+
d$type <- "MLE"
191+
d_boot$type <- "Bootstrap"
192+
d <- rbind(d, d_boot)
193+
}
131194
d
132195
}

R/update.R

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ update.nhmm <- function(object, newdata, ...) {
6666
}
6767
)
6868
object$observations <- .check_observations(observations, object$channel_names)
69+
object$sequence_lengths <- attr(object$observations, "sequence_lengths")
6970
object
7071
}
7172
#' @rdname update_nhmm

src/RcppExports.cpp

+62-2
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,7 @@ BEGIN_RCPP
910910
END_RCPP
911911
}
912912
// predict_nhmm_singlechannel
913-
Rcpp::List predict_nhmm_singlechannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B);
913+
arma::cube predict_nhmm_singlechannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B);
914914
RcppExport SEXP _seqHMM_predict_nhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP) {
915915
BEGIN_RCPP
916916
Rcpp::RObject rcpp_result_gen;
@@ -935,7 +935,7 @@ BEGIN_RCPP
935935
END_RCPP
936936
}
937937
// predict_fanhmm_singlechannel
938-
Rcpp::List predict_fanhmm_singlechannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::field<arma::cube>& W_A, const arma::field<arma::cube>& W_B);
938+
arma::cube predict_fanhmm_singlechannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::field<arma::cube>& W_A, const arma::field<arma::cube>& W_B);
939939
RcppExport SEXP _seqHMM_predict_fanhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP W_ASEXP, SEXP W_BSEXP) {
940940
BEGIN_RCPP
941941
Rcpp::RObject rcpp_result_gen;
@@ -961,6 +961,64 @@ BEGIN_RCPP
961961
return rcpp_result_gen;
962962
END_RCPP
963963
}
964+
// boot_predict_nhmm_singlechannel
965+
arma::field<arma::cube> boot_predict_nhmm_singlechannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::field<arma::mat>& gamma_pi, const arma::field<arma::cube>& gamma_A, const arma::field<arma::cube>& gamma_B);
966+
RcppExport SEXP _seqHMM_boot_predict_nhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP gamma_piSEXP, SEXP gamma_ASEXP, SEXP gamma_BSEXP) {
967+
BEGIN_RCPP
968+
Rcpp::RObject rcpp_result_gen;
969+
Rcpp::RNGScope rcpp_rngScope_gen;
970+
Rcpp::traits::input_parameter< arma::mat& >::type eta_pi(eta_piSEXP);
971+
Rcpp::traits::input_parameter< const arma::mat& >::type X_pi(X_piSEXP);
972+
Rcpp::traits::input_parameter< arma::cube& >::type eta_A(eta_ASEXP);
973+
Rcpp::traits::input_parameter< const arma::cube& >::type X_A(X_ASEXP);
974+
Rcpp::traits::input_parameter< arma::cube& >::type eta_B(eta_BSEXP);
975+
Rcpp::traits::input_parameter< const arma::cube& >::type X_B(X_BSEXP);
976+
Rcpp::traits::input_parameter< const arma::umat& >::type obs(obsSEXP);
977+
Rcpp::traits::input_parameter< const arma::uvec >::type Ti(TiSEXP);
978+
Rcpp::traits::input_parameter< const bool >::type icpt_only_pi(icpt_only_piSEXP);
979+
Rcpp::traits::input_parameter< const bool >::type icpt_only_A(icpt_only_ASEXP);
980+
Rcpp::traits::input_parameter< const bool >::type icpt_only_B(icpt_only_BSEXP);
981+
Rcpp::traits::input_parameter< const bool >::type iv_A(iv_ASEXP);
982+
Rcpp::traits::input_parameter< const bool >::type iv_B(iv_BSEXP);
983+
Rcpp::traits::input_parameter< const bool >::type tv_A(tv_ASEXP);
984+
Rcpp::traits::input_parameter< const bool >::type tv_B(tv_BSEXP);
985+
Rcpp::traits::input_parameter< const arma::field<arma::mat>& >::type gamma_pi(gamma_piSEXP);
986+
Rcpp::traits::input_parameter< const arma::field<arma::cube>& >::type gamma_A(gamma_ASEXP);
987+
Rcpp::traits::input_parameter< const arma::field<arma::cube>& >::type gamma_B(gamma_BSEXP);
988+
rcpp_result_gen = Rcpp::wrap(boot_predict_nhmm_singlechannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, gamma_pi, gamma_A, gamma_B));
989+
return rcpp_result_gen;
990+
END_RCPP
991+
}
992+
// boot_predict_fanhmm_singlechannel
993+
arma::field<arma::cube> boot_predict_fanhmm_singlechannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::field<arma::cube>& W_A, const arma::field<arma::cube>& W_B, const arma::field<arma::mat>& gamma_pi, const arma::field<arma::cube>& gamma_A, const arma::field<arma::cube>& gamma_B);
994+
RcppExport SEXP _seqHMM_boot_predict_fanhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP W_ASEXP, SEXP W_BSEXP, SEXP gamma_piSEXP, SEXP gamma_ASEXP, SEXP gamma_BSEXP) {
995+
BEGIN_RCPP
996+
Rcpp::RObject rcpp_result_gen;
997+
Rcpp::RNGScope rcpp_rngScope_gen;
998+
Rcpp::traits::input_parameter< arma::mat& >::type eta_pi(eta_piSEXP);
999+
Rcpp::traits::input_parameter< const arma::mat& >::type X_pi(X_piSEXP);
1000+
Rcpp::traits::input_parameter< arma::cube& >::type eta_A(eta_ASEXP);
1001+
Rcpp::traits::input_parameter< const arma::cube& >::type X_A(X_ASEXP);
1002+
Rcpp::traits::input_parameter< arma::cube& >::type eta_B(eta_BSEXP);
1003+
Rcpp::traits::input_parameter< const arma::cube& >::type X_B(X_BSEXP);
1004+
Rcpp::traits::input_parameter< const arma::umat& >::type obs(obsSEXP);
1005+
Rcpp::traits::input_parameter< const arma::uvec >::type Ti(TiSEXP);
1006+
Rcpp::traits::input_parameter< const bool >::type icpt_only_pi(icpt_only_piSEXP);
1007+
Rcpp::traits::input_parameter< const bool >::type icpt_only_A(icpt_only_ASEXP);
1008+
Rcpp::traits::input_parameter< const bool >::type icpt_only_B(icpt_only_BSEXP);
1009+
Rcpp::traits::input_parameter< const bool >::type iv_A(iv_ASEXP);
1010+
Rcpp::traits::input_parameter< const bool >::type iv_B(iv_BSEXP);
1011+
Rcpp::traits::input_parameter< const bool >::type tv_A(tv_ASEXP);
1012+
Rcpp::traits::input_parameter< const bool >::type tv_B(tv_BSEXP);
1013+
Rcpp::traits::input_parameter< const arma::field<arma::cube>& >::type W_A(W_ASEXP);
1014+
Rcpp::traits::input_parameter< const arma::field<arma::cube>& >::type W_B(W_BSEXP);
1015+
Rcpp::traits::input_parameter< const arma::field<arma::mat>& >::type gamma_pi(gamma_piSEXP);
1016+
Rcpp::traits::input_parameter< const arma::field<arma::cube>& >::type gamma_A(gamma_ASEXP);
1017+
Rcpp::traits::input_parameter< const arma::field<arma::cube>& >::type gamma_B(gamma_BSEXP);
1018+
rcpp_result_gen = Rcpp::wrap(boot_predict_fanhmm_singlechannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, W_A, W_B, gamma_pi, gamma_A, gamma_B));
1019+
return rcpp_result_gen;
1020+
END_RCPP
1021+
}
9641022
// simulate_nhmm_singlechannel
9651023
Rcpp::List simulate_nhmm_singlechannel(const arma::mat& eta_pi, const arma::mat& X_pi, const arma::cube& eta_A, const arma::cube& X_A, const arma::cube& eta_B, const arma::cube& X_B);
9661024
RcppExport SEXP _seqHMM_simulate_nhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP) {
@@ -1692,6 +1750,8 @@ static const R_CallMethodDef CallEntries[] = {
16921750
{"_seqHMM_log_objective_mnhmm_multichannel", (DL_FUNC) &_seqHMM_log_objective_mnhmm_multichannel, 18},
16931751
{"_seqHMM_predict_nhmm_singlechannel", (DL_FUNC) &_seqHMM_predict_nhmm_singlechannel, 15},
16941752
{"_seqHMM_predict_fanhmm_singlechannel", (DL_FUNC) &_seqHMM_predict_fanhmm_singlechannel, 17},
1753+
{"_seqHMM_boot_predict_nhmm_singlechannel", (DL_FUNC) &_seqHMM_boot_predict_nhmm_singlechannel, 18},
1754+
{"_seqHMM_boot_predict_fanhmm_singlechannel", (DL_FUNC) &_seqHMM_boot_predict_fanhmm_singlechannel, 20},
16951755
{"_seqHMM_simulate_nhmm_singlechannel", (DL_FUNC) &_seqHMM_simulate_nhmm_singlechannel, 6},
16961756
{"_seqHMM_simulate_nhmm_multichannel", (DL_FUNC) &_seqHMM_simulate_nhmm_multichannel, 7},
16971757
{"_seqHMM_simulate_mnhmm_singlechannel", (DL_FUNC) &_seqHMM_simulate_mnhmm_singlechannel, 8},

src/nhmm_predict.cpp

+63-28
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "mnhmm_mc.h"
77

88
// [[Rcpp::export]]
9-
Rcpp::List predict_nhmm_singlechannel(
9+
arma::cube predict_nhmm_singlechannel(
1010
arma::mat& eta_pi, const arma::mat& X_pi,
1111
arma::cube& eta_A, const arma::cube& X_A,
1212
arma::cube& eta_B, const arma::cube& X_B,
@@ -21,12 +21,10 @@ Rcpp::List predict_nhmm_singlechannel(
2121
arma::cube obs_prob(model.M, model.T, model.N, arma::fill::value(arma::datum::nan));
2222
model.predict(obs_prob);
2323

24-
return Rcpp::List::create(
25-
Rcpp::Named("obs_prob") = Rcpp::wrap(obs_prob)
26-
);
24+
return obs_prob;
2725
}
2826
// [[Rcpp::export]]
29-
Rcpp::List predict_fanhmm_singlechannel(
27+
arma::cube predict_fanhmm_singlechannel(
3028
arma::mat& eta_pi, const arma::mat& X_pi,
3129
arma::cube& eta_A, const arma::cube& X_A,
3230
arma::cube& eta_B, const arma::cube& X_B,
@@ -42,9 +40,65 @@ Rcpp::List predict_fanhmm_singlechannel(
4240
arma::cube obs_prob(model.M, model.T, model.N, arma::fill::value(arma::datum::nan));
4341
model.predict_fanhmm(obs_prob, W_A, W_B);
4442

45-
return Rcpp::List::create(
46-
Rcpp::Named("obs_prob") = Rcpp::wrap(obs_prob)
43+
return obs_prob;
44+
}
45+
46+
// [[Rcpp::export]]
47+
arma::field<arma::cube> boot_predict_nhmm_singlechannel(
48+
arma::mat& eta_pi, const arma::mat& X_pi,
49+
arma::cube& eta_A, const arma::cube& X_A,
50+
arma::cube& eta_B, const arma::cube& X_B,
51+
const arma::umat& obs, const arma::uvec Ti,
52+
const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B,
53+
const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B,
54+
const arma::field<arma::mat>& gamma_pi,
55+
const arma::field<arma::cube>& gamma_A,
56+
const arma::field<arma::cube>& gamma_B) {
57+
58+
nhmm_sc model(
59+
eta_A.n_slices, X_pi, X_A, X_B, Ti, icpt_only_pi, icpt_only_A,
60+
icpt_only_B, iv_A, iv_B, tv_A, tv_B, obs, eta_pi, eta_A, eta_B
61+
);
62+
arma::uword nsim = gamma_pi.n_elem;
63+
arma::field<arma::cube> obs_prob(nsim);
64+
for (arma::uword j = 0; j < nsim; j++) {
65+
model.gamma_pi = gamma_pi(j);
66+
model.gamma_A = gamma_A(j);
67+
model.gamma_B = gamma_B(j);
68+
obs_prob(j) = arma::cube(model.M, model.T, model.N, arma::fill::value(arma::datum::nan));
69+
model.predict(obs_prob(j));
70+
71+
}
72+
return obs_prob;
73+
}
74+
// [[Rcpp::export]]
75+
arma::field<arma::cube> boot_predict_fanhmm_singlechannel(
76+
arma::mat& eta_pi, const arma::mat& X_pi,
77+
arma::cube& eta_A, const arma::cube& X_A,
78+
arma::cube& eta_B, const arma::cube& X_B,
79+
const arma::umat& obs, const arma::uvec Ti,
80+
const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B,
81+
const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B,
82+
const arma::field<arma::cube>& W_A, const arma::field<arma::cube>& W_B,
83+
const arma::field<arma::mat>& gamma_pi,
84+
const arma::field<arma::cube>& gamma_A,
85+
const arma::field<arma::cube>& gamma_B) {
86+
87+
nhmm_sc model(
88+
eta_A.n_slices, X_pi, X_A, X_B, Ti, icpt_only_pi, icpt_only_A,
89+
icpt_only_B, iv_A, iv_B, tv_A, tv_B, obs, eta_pi, eta_A, eta_B
4790
);
91+
arma::uword nsim = gamma_pi.n_elem;
92+
arma::field<arma::cube> obs_prob(nsim);
93+
for (arma::uword j = 0; j < nsim; j++) {
94+
model.gamma_pi = gamma_pi(j);
95+
model.gamma_A = gamma_A(j);
96+
model.gamma_B = gamma_B(j);
97+
obs_prob(j) = arma::cube(model.M, model.T, model.N, arma::fill::value(arma::datum::nan));
98+
model.predict_fanhmm(obs_prob(j), W_A, W_B);
99+
100+
}
101+
return obs_prob;
48102
}
49103
void nhmm_sc::predict(arma::cube& obs_prob) {
50104

@@ -94,12 +148,7 @@ void nhmm_sc::predict_fanhmm(
94148
// P(z_1)
95149
alpha = pi;
96150
// P(y_1)
97-
// if (obs(0, i) < M) {
98-
// obs_prob.slice(i).col(0).zeros();
99-
// obs_prob(obs(0, i), 0, i) = 1.0;
100-
// } else {
101-
obs_prob.slice(i).col(0) = B.slice(0).cols(0, M - 1).t() * alpha;
102-
// }
151+
obs_prob.slice(i).col(0) = B.slice(0).cols(0, M - 1).t() * alpha;
103152
// P(z_1) P(y_1| y_1) = P(z_1, y_1)
104153
alpha %= B.slice(0).col(obs(0, i));
105154
// P(z_1 | y_1)
@@ -109,12 +158,7 @@ void nhmm_sc::predict_fanhmm(
109158
// P(alpha_t | y_t-1, ..., y_1)
110159
alpha = A.slice(t - 1).t() * alpha;
111160
// P(y_t | y_t-1,...,y_1)
112-
// if (obs(t, i) < M) {
113-
// obs_prob.slice(i).col(t).zeros();
114-
// obs_prob(obs(t, i), t, i) = 1.0;
115-
// } else {
116-
obs_prob.slice(i).col(t) = B.slice(t).cols(0, M - 1).t() * alpha;
117-
// }
161+
obs_prob.slice(i).col(t) = B.slice(t).cols(0, M - 1).t() * alpha;
118162
// P(alpha_t, y_t | y_t-1, ..., y_1) (or P(alpha_t | y_t-1, ..., y_1) if y_t missing)
119163
alpha %= B.slice(t).col(obs(t, i));
120164
// P(alpha_t | y_t, ..., y_1) (or P(alpha_t | y_t-1, ..., y_1) if y_t missing)
@@ -132,18 +176,9 @@ void nhmm_sc::predict_fanhmm(
132176
// P(alpha_t | y_t-1 = m, y_t-2, ..., y_1)
133177
alpha_new.col(m) = A_tm1.slice(m).t() * alpha * obs_prob(m, t - 1, i);
134178
obs_prob.slice(i).col(t) += B_t.slice(m).cols(0, M - 1).t() * alpha_new.col(m);
135-
// alpha_new.col(m) = A_tm1.slice(m).t() * alpha * obs_prob(m, t - 1, i);
136-
// if (obs(t, i) < M) {
137-
// obs_prob.slice(i).col(t).zeros();
138-
// obs_prob(obs(t, i), t, i) = 1.0;
139-
// } else {
140-
// // obs prediction when previous y=m
141-
// obs_prob.slice(i).col(t) += B_t.slice(m).cols(0, M - 1).t() * alpha_new.col(m);
142-
// }
143179
// P(alpha_t, y_t | y_t-1 = m, y_t-2, ..., y_1)
144180
alpha_new.col(m) %= B_t.slice(m).col(obs(t, i));
145181
}
146-
147182
// P(alpha_t, y_t | y_t-2,..., y_1)
148183
alpha = arma::sum(alpha_new, 1);
149184
// P(alpha_t | y_t, y_t-2, ..., y_1)

0 commit comments

Comments
 (0)