Skip to content

Commit 07848a7

Browse files
Merge pull request #54 from JeffreyCHoover/aic-bic
AIC and BIC
2 parents d558c3f + bd7d1ef commit 07848a7

8 files changed

+133
-11
lines changed

Diff for: NEWS.md

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
* Added new `measrfit()` function for creating measrfit objects from *Stan* models that were not originally created with measr.
66

7+
* Added `aic()` and `bic()` functions for calculating the Akaike and Bayesian information criteria, respectively, for models estimated with `method = "optim"`.
8+
79
# measr 1.0.0
810

911
## New documentation

Diff for: R/aic-bic.R

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
aic <- function(model) {
2+
model <- check_model(model, required_class = "measrfit", name = "model")
3+
log_lik <- model$model$value
4+
5+
num_params <- model$model$par %>%
6+
tibble::as_tibble() %>%
7+
dplyr::mutate(param = names(model$model$par)) %>%
8+
dplyr::filter(!grepl("pi", .data$param),
9+
!grepl("log_Vc", .data$param)) %>%
10+
nrow() - 1
11+
12+
aic <- (-2 * log_lik) + (2 * num_params)
13+
14+
return(aic)
15+
}
16+
17+
bic <- function(model) {
18+
model <- check_model(model, required_class = "measrfit", name = "model")
19+
log_lik <- model$model$value
20+
21+
num_params <- model$model$par %>%
22+
tibble::as_tibble() %>%
23+
dplyr::mutate(param = names(model$model$par)) %>%
24+
dplyr::filter(!grepl("pi", .data$param),
25+
!grepl("log_Vc", .data$param)) %>%
26+
nrow() - 1
27+
28+
n <- model$data$data %>%
29+
dplyr::distinct(.data$resp_id) %>%
30+
nrow()
31+
32+
bic <- (-2 * log_lik) + (log(n) * num_params)
33+
34+
return(bic)
35+
}

Diff for: R/model-evaluation.R

+22-8
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
#' @inheritParams dcm2::calc_m2
1111
#' @param x A [measrfit] object.
1212
#' @param criterion A vector of criteria to calculate and add to the model
13-
#' object.
13+
#' object. Must be one of `"loo"` or `"waic"` for models estimated with MCMC,
14+
#' or one of `"aic"` or `"bic"` for model estimated with the optimizer.
1415
#' @param method A vector of model fit methods to evaluate and add to the model
1516
#' object.
1617
#' @param probs The percentiles to be computed by the [stats::quantile()]
@@ -117,16 +118,23 @@ NULL
117118

118119
#' @export
119120
#' @rdname model_evaluation
120-
add_criterion <- function(x, criterion = c("loo", "waic"), overwrite = FALSE,
121-
save = TRUE, ..., r_eff = NA) {
121+
add_criterion <- function(x, criterion = c("loo", "waic"),
122+
overwrite = FALSE, save = TRUE, ..., r_eff = NA) {
122123
model <- check_model(x, required_class = "measrfit", name = "x")
123-
if (model$method != "mcmc") {
124+
if (any(model$method != "mcmc" && any(criterion %in% c("loo", "waic")))) {
124125
rlang::abort("error_bad_method",
125-
message = glue::glue("Model criteria are only available for ",
126-
"models estimated with ",
126+
message = glue::glue("LOO and WAIC model criteria are only ",
127+
"available for models estimated with ",
127128
"`method = \"mcmc\"`."))
129+
} else if (any(model$method != "optim" &&
130+
any(criterion %in% c("aic", "bic")))) {
131+
rlang::abort("error_bad_method",
132+
message = glue::glue("AIC and BIC model criteria are only ",
133+
"available for models estimated with ",
134+
"`method = \"optim\"`."))
128135
}
129-
criterion <- rlang::arg_match(criterion, values = c("loo", "waic"),
136+
criterion <- rlang::arg_match(criterion,
137+
values = c("loo", "waic", "aic", "bic"),
130138
multiple = TRUE)
131139
overwrite <- check_logical(overwrite, name = "overwrite")
132140
save <- check_logical(save, name = "force_save")
@@ -140,7 +148,7 @@ add_criterion <- function(x, criterion = c("loo", "waic"), overwrite = FALSE,
140148
}
141149
all_criteria <- c(new_criteria, redo_criteria)
142150

143-
if (length(all_criteria) > 0) {
151+
if (length(all_criteria) > 0 && (model$method == "mcmc")) {
144152
log_lik_array <- loglik_array(model)
145153
}
146154

@@ -150,6 +158,12 @@ add_criterion <- function(x, criterion = c("loo", "waic"), overwrite = FALSE,
150158
if ("waic" %in% all_criteria) {
151159
model$criteria$waic <- waic(log_lik_array)
152160
}
161+
if ("aic" %in% all_criteria) {
162+
model$criteria$aic <- aic(model)
163+
}
164+
if ("bic" %in% all_criteria) {
165+
model$criteria$bic <- bic(model)
166+
}
153167

154168
# re-save model object (if applicable)
155169
if (!is.null(model$file) && length(all_criteria) > 0 && save) {

Diff for: man/model_evaluation.Rd

+2-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: tests/testthat/test-ecpe.R

+11
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,14 @@ test_that("mcmc requirements error", {
293293
expect_s3_class(err, "error_bad_method")
294294
expect_match(err$message, "`method = \"mcmc\"`")
295295
})
296+
297+
test_that("optim requirements error", {
298+
skip_on_cran()
299+
300+
mcmc_mod <- cmds_ecpe_lcdm
301+
mcmc_mod$method <- "mcmc"
302+
303+
err <- rlang::catch_cnd(add_criterion(mcmc_mod, criterion = c("aic", "bic")))
304+
expect_s3_class(err, "error_bad_method")
305+
expect_match(err$message, "`method = \"optim\"`")
306+
})

Diff for: tests/testthat/test-mcmc.R

+1-2
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,7 @@ test_that("loo and waic can be added to model", {
116116
expect_equal(names(loo_model$criteria), "loo")
117117
expect_s3_class(loo_model$criteria$loo, "psis_loo")
118118

119-
lw_model <- add_criterion(loo_model, criterion = c("loo", "waic"),
120-
overwrite = TRUE)
119+
lw_model <- add_criterion(loo_model, overwrite = TRUE)
121120
expect_equal(names(lw_model$criteria), c("loo", "waic"))
122121
expect_s3_class(lw_model$criteria$loo, "psis_loo")
123122
expect_s3_class(lw_model$criteria$waic, "waic")

Diff for: tests/testthat/test-model-evaluation.R

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
test_that("add criterion error messages work", {
2+
err <- rlang::catch_cnd(add_criterion("test"))
3+
expect_s3_class(err, "error_bad_argument")
4+
expect_match(err$message, "must be an object with class measrfit")
5+
6+
7+
err <- rlang::catch_cnd(add_criterion(rstn_dino, criterion = "waic"))
8+
expect_s3_class(err, "error_bad_method")
9+
expect_match(err$message, "LOO and WAIC model criteria are only available")
10+
11+
err <- rlang::catch_cnd(add_criterion(rstn_dino, criterion = "waic"))
12+
expect_s3_class(err, "error_bad_method")
13+
expect_match(err$message, "LOO and WAIC model criteria are only available")
14+
15+
test_dino <- rstn_dino
16+
test_dino$method <- "mcmc"
17+
err <- rlang::catch_cnd(add_criterion(test_dino, criterion = "aic"))
18+
expect_s3_class(err, "error_bad_method")
19+
expect_match(err$message, "AIC and BIC model criteria are only available")
20+
21+
err <- rlang::catch_cnd(add_criterion(test_dino, criterion = "bic"))
22+
expect_s3_class(err, "error_bad_method")
23+
expect_match(err$message, "AIC and BIC model criteria are only available")
24+
})
25+
26+
test_that("AIC works", {
27+
rstn_dino <- add_criterion(rstn_dino, criterion = "aic")
28+
29+
exp_aic <- 37151.96
30+
31+
expect_equal(rstn_dino$criteria$aic, exp_aic, tolerance = .0001)
32+
})
33+
34+
test_that("BIC works", {
35+
rstn_dino <- add_criterion(rstn_dino, criterion = "bic")
36+
37+
exp_bic <- 37647.64
38+
39+
expect_equal(rstn_dino$criteria$bic, exp_bic, tolerance = .0001)
40+
})

Diff for: tests/testthat/test-utils-model-evaluation.R

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
test_that("aic works", {
2+
num_params <- 101
3+
log_lik <- -18474.98
4+
5+
exp_aic <- (-2 * log_lik) + (2 * num_params)
6+
aic_val <- aic(rstn_dino)
7+
8+
expect_equal(exp_aic, aic_val)
9+
})
10+
11+
test_that("bic works", {
12+
num_params <- 101
13+
n <- 1000
14+
log_lik <- -18474.98
15+
16+
exp_bic <- (-2 * log_lik) + (log(n) * num_params)
17+
bic_val <- bic(rstn_dino)
18+
19+
expect_equal(exp_bic, bic_val)
20+
})

0 commit comments

Comments
 (0)