Skip to content

Commit 7cf7281

Browse files
committed
Add factory methods for summary, proposal, kernel functions
1 parent f02d864 commit 7cf7281

File tree

4 files changed

+103
-16
lines changed

4 files changed

+103
-16
lines changed

Diff for: R/cpp11.R

+12
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,10 @@ set_observed_data_cpp <- function(lfmcmc, observed_data_) {
232232
.Call(`_epiworldR_set_observed_data_cpp`, lfmcmc, observed_data_)
233233
}
234234

235+
create_LFMCMCProposalFun_cpp <- function(fun) {
236+
.Call(`_epiworldR_create_LFMCMCProposalFun_cpp`, fun)
237+
}
238+
235239
set_proposal_fun_cpp <- function(lfmcmc, fun) {
236240
.Call(`_epiworldR_set_proposal_fun_cpp`, lfmcmc, fun)
237241
}
@@ -244,10 +248,18 @@ set_simulation_fun_cpp <- function(lfmcmc, fun) {
244248
.Call(`_epiworldR_set_simulation_fun_cpp`, lfmcmc, fun)
245249
}
246250

251+
create_LFMCMCSummaryFun_cpp <- function(fun) {
252+
.Call(`_epiworldR_create_LFMCMCSummaryFun_cpp`, fun)
253+
}
254+
247255
set_summary_fun_cpp <- function(lfmcmc, fun) {
248256
.Call(`_epiworldR_set_summary_fun_cpp`, lfmcmc, fun)
249257
}
250258

259+
create_LFMCMCKernelFun_cpp <- function(fun) {
260+
.Call(`_epiworldR_create_LFMCMCKernelFun_cpp`, fun)
261+
}
262+
251263
set_kernel_fun_cpp <- function(lfmcmc, fun) {
252264
.Call(`_epiworldR_set_kernel_fun_cpp`, lfmcmc, fun)
253265
}

Diff for: src/cpp11.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,13 @@ extern "C" SEXP _epiworldR_set_observed_data_cpp(SEXP lfmcmc, SEXP observed_data
412412
END_CPP11
413413
}
414414
// lfmcmc.cpp
415+
SEXP create_LFMCMCProposalFun_cpp(cpp11::function fun);
416+
extern "C" SEXP _epiworldR_create_LFMCMCProposalFun_cpp(SEXP fun) {
417+
BEGIN_CPP11
418+
return cpp11::as_sexp(create_LFMCMCProposalFun_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::function>>(fun)));
419+
END_CPP11
420+
}
421+
// lfmcmc.cpp
415422
SEXP set_proposal_fun_cpp(SEXP lfmcmc, SEXP fun);
416423
extern "C" SEXP _epiworldR_set_proposal_fun_cpp(SEXP lfmcmc, SEXP fun) {
417424
BEGIN_CPP11
@@ -433,13 +440,27 @@ extern "C" SEXP _epiworldR_set_simulation_fun_cpp(SEXP lfmcmc, SEXP fun) {
433440
END_CPP11
434441
}
435442
// lfmcmc.cpp
443+
SEXP create_LFMCMCSummaryFun_cpp(cpp11::function fun);
444+
extern "C" SEXP _epiworldR_create_LFMCMCSummaryFun_cpp(SEXP fun) {
445+
BEGIN_CPP11
446+
return cpp11::as_sexp(create_LFMCMCSummaryFun_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::function>>(fun)));
447+
END_CPP11
448+
}
449+
// lfmcmc.cpp
436450
SEXP set_summary_fun_cpp(SEXP lfmcmc, SEXP fun);
437451
extern "C" SEXP _epiworldR_set_summary_fun_cpp(SEXP lfmcmc, SEXP fun) {
438452
BEGIN_CPP11
439453
return cpp11::as_sexp(set_summary_fun_cpp(cpp11::as_cpp<cpp11::decay_t<SEXP>>(lfmcmc), cpp11::as_cpp<cpp11::decay_t<SEXP>>(fun)));
440454
END_CPP11
441455
}
442456
// lfmcmc.cpp
457+
SEXP create_LFMCMCKernelFun_cpp(cpp11::function fun);
458+
extern "C" SEXP _epiworldR_create_LFMCMCKernelFun_cpp(SEXP fun) {
459+
BEGIN_CPP11
460+
return cpp11::as_sexp(create_LFMCMCKernelFun_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::function>>(fun)));
461+
END_CPP11
462+
}
463+
// lfmcmc.cpp
443464
SEXP set_kernel_fun_cpp(SEXP lfmcmc, SEXP fun);
444465
extern "C" SEXP _epiworldR_set_kernel_fun_cpp(SEXP lfmcmc, SEXP fun) {
445466
BEGIN_CPP11
@@ -1021,7 +1042,10 @@ static const R_CallMethodDef CallEntries[] = {
10211042
{"_epiworldR_agents_smallworld_cpp", (DL_FUNC) &_epiworldR_agents_smallworld_cpp, 5},
10221043
{"_epiworldR_change_state_cpp", (DL_FUNC) &_epiworldR_change_state_cpp, 4},
10231044
{"_epiworldR_clone_model_cpp", (DL_FUNC) &_epiworldR_clone_model_cpp, 1},
1045+
{"_epiworldR_create_LFMCMCKernelFun_cpp", (DL_FUNC) &_epiworldR_create_LFMCMCKernelFun_cpp, 1},
1046+
{"_epiworldR_create_LFMCMCProposalFun_cpp", (DL_FUNC) &_epiworldR_create_LFMCMCProposalFun_cpp, 1},
10241047
{"_epiworldR_create_LFMCMCSimFun_cpp", (DL_FUNC) &_epiworldR_create_LFMCMCSimFun_cpp, 1},
1048+
{"_epiworldR_create_LFMCMCSummaryFun_cpp", (DL_FUNC) &_epiworldR_create_LFMCMCSummaryFun_cpp, 1},
10251049
{"_epiworldR_distribute_entity_randomly_cpp", (DL_FUNC) &_epiworldR_distribute_entity_randomly_cpp, 3},
10261050
{"_epiworldR_distribute_entity_to_set_cpp", (DL_FUNC) &_epiworldR_distribute_entity_to_set_cpp, 1},
10271051
{"_epiworldR_distribute_tool_randomly_cpp", (DL_FUNC) &_epiworldR_distribute_tool_randomly_cpp, 2},

Diff for: src/lfmcmc.cpp

+58-6
Original file line numberDiff line numberDiff line change
@@ -46,27 +46,44 @@ SEXP set_observed_data_cpp(
4646
return lfmcmc;
4747
}
4848

49+
// LFMCMC Proposal Function
50+
[[cpp11::register]]
51+
SEXP create_LFMCMCProposalFun_cpp(
52+
cpp11::function fun
53+
) {
54+
55+
LFMCMCProposalFun<TData_default> fun_call = [fun](std::vector< epiworld_double >& params_now,const std::vector< epiworld_double >& params_prev, LFMCMC<TData_default>* model) -> void {
56+
WrapLFMCMC(lfmcmc_ptr)(model);
57+
fun(params_now, params_prev, lfmcmc_ptr);
58+
return;
59+
};
60+
61+
return cpp11::external_pointer<LFMCMCProposalFun<TData_default>>(
62+
new LFMCMCProposalFun<TData_default>(fun_call)
63+
);
64+
}
65+
4966
[[cpp11::register]]
5067
SEXP set_proposal_fun_cpp(
5168
SEXP lfmcmc,
5269
SEXP fun
5370
) {
54-
cpp11::external_pointer<LFMCMCProposalFun<TData_default>> fun_ptr(fun);
71+
cpp11::external_pointer<LFMCMCProposalFun<TData_default>> fun_ptr = create_LFMCMCProposalFun_cpp(fun);
5572
WrapLFMCMC(lfmcmc_ptr)(lfmcmc);
5673
lfmcmc_ptr->set_proposal_fun(*fun_ptr);
5774
return lfmcmc;
5875
}
5976

77+
// LFMCMC Simulation Function
6078
[[cpp11::register]]
6179
SEXP create_LFMCMCSimFun_cpp(
6280
cpp11::function fun
6381
) {
6482

6583
LFMCMCSimFun<TData_default> fun_call = [fun](const std::vector<epiworld_double>& params, LFMCMC<TData_default>* model) -> TData_default {
6684
WrapLFMCMC(lfmcmc_ptr)(model);
67-
SEXP res = fun(params, lfmcmc_ptr);
68-
cpp11::external_pointer<TData_default> res_vec(res);
69-
return *res_vec;
85+
cpp11::external_pointer<TData_default> res(fun(params, lfmcmc_ptr));
86+
return *res;
7087
};
7188

7289
return cpp11::external_pointer<LFMCMCSimFun<TData_default>>(
@@ -85,23 +102,58 @@ SEXP set_simulation_fun_cpp(
85102
return lfmcmc;
86103
}
87104

105+
// LFMCMC Summary Function
106+
[[cpp11::register]]
107+
SEXP create_LFMCMCSummaryFun_cpp(
108+
cpp11::function fun
109+
) {
110+
111+
LFMCMCSummaryFun<TData_default> fun_call = [fun](std::vector< epiworld_double >& res, const TData_default& dat, LFMCMC<TData_default>* model) -> void {
112+
WrapLFMCMC(lfmcmc_ptr)(model);
113+
fun(res, dat, lfmcmc_ptr);
114+
return;
115+
};
116+
117+
return cpp11::external_pointer<LFMCMCSummaryFun<TData_default>>(
118+
new LFMCMCSummaryFun<TData_default>(fun_call)
119+
);
120+
}
121+
88122
[[cpp11::register]]
89123
SEXP set_summary_fun_cpp(
90124
SEXP lfmcmc,
91125
SEXP fun
92126
) {
93-
cpp11::external_pointer<LFMCMCSummaryFun<TData_default>> fun_ptr(fun);
127+
cpp11::external_pointer<LFMCMCSummaryFun<TData_default>> fun_ptr = create_LFMCMCSummaryFun_cpp(fun);
94128
WrapLFMCMC(lfmcmc_ptr)(lfmcmc);
95129
lfmcmc_ptr->set_summary_fun(*fun_ptr);
96130
return lfmcmc;
97131
}
98132

133+
// LFMCMC Kernel Function
134+
// TODO: clean up these really long lines
135+
[[cpp11::register]]
136+
SEXP create_LFMCMCKernelFun_cpp(
137+
cpp11::function fun
138+
) {
139+
140+
LFMCMCKernelFun<TData_default> fun_call = [fun](const std::vector< epiworld_double >& stats_now, const std::vector< epiworld_double >& stats_obs, epiworld_double epsilon, LFMCMC<TData_default>* model) -> epiworld_double {
141+
WrapLFMCMC(lfmcmc_ptr)(model);
142+
cpp11::external_pointer<epiworld_double> res(fun(stats_now, stats_obs, epsilon, lfmcmc_ptr));
143+
return *res;
144+
};
145+
146+
return cpp11::external_pointer<LFMCMCKernelFun<TData_default>>(
147+
new LFMCMCKernelFun<TData_default>(fun_call)
148+
);
149+
}
150+
99151
[[cpp11::register]]
100152
SEXP set_kernel_fun_cpp(
101153
SEXP lfmcmc,
102154
SEXP fun
103155
) {
104-
cpp11::external_pointer<LFMCMCKernelFun<TData_default>> fun_ptr(fun);
156+
cpp11::external_pointer<LFMCMCKernelFun<TData_default>> fun_ptr = create_LFMCMCKernelFun_cpp(fun);
105157
WrapLFMCMC(lfmcmc_ptr)(lfmcmc);
106158
lfmcmc_ptr->set_kernel_fun(*fun_ptr);
107159
return lfmcmc;

Diff for: vignettes/likelihood-free-mcmc.Rmd

+9-10
Original file line numberDiff line numberDiff line change
@@ -76,21 +76,21 @@ simfun <- function(params, m) {
7676
}
7777
# TODO: Define Summary Function
7878
sumfun <- function(res, dat, m) {
79-
# if (res.size() == 0)
79+
# if (length(res) == 0)
8080
# res.resize(data.size())
8181
8282
# for (i in dat.size())
8383
# res[i] = static_cast<float>(dat[i])
8484
85-
# return
85+
return
8686
}
8787
# TODO: Define Proposal Function
88-
propfun <- function(scale, lb, ub) {
89-
88+
propfun <- function(params_now, params_prev, m) {
89+
return
9090
}
9191
# TODO: Define Kernel Function
9292
kernfun <- function() {
93-
93+
return(1.0)
9494
}
9595
9696
# Set initial parameters
@@ -101,11 +101,10 @@ par0 <- c(.5, .5)
101101
```{r lfmcmc-run}
102102
# TODO: make these work
103103
lfmcmc_model <- LFMCMC() |>
104-
set_simulation_fun(simfun)
105-
# set_simulation_fun(lfmcmc_model, simfun)
106-
# set_summary_fun(sumfun) |>
107-
# set_proposal_fun(propfun) |>
108-
# set_kernel_fun(kernfun) |>
104+
set_simulation_fun(simfun) |>
105+
set_summary_fun(sumfun) |>
106+
set_proposal_fun(propfun) |>
107+
set_kernel_fun(kernfun)
109108
# set_observed_data(obs_dat) |>
110109
# run_lfmcmc(par0, 2000, 1)
111110

0 commit comments

Comments
 (0)