Skip to content

Commit ef65227

Browse files
committed
add test for interval prediction
1 parent 3d77e01 commit ef65227

File tree

1 file changed

+30
-12
lines changed

1 file changed

+30
-12
lines changed

tests/testthat/test-layer_predict.R

+30-12
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,40 @@
1+
jhu <- case_death_rate_subset %>%
2+
dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))
3+
r <- epi_recipe(jhu) %>%
4+
step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
5+
step_epi_ahead(death_rate, ahead = 7) %>%
6+
step_naomit(all_predictors()) %>%
7+
step_naomit(all_outcomes(), skip = TRUE)
8+
wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
9+
latest <- jhu %>%
10+
dplyr::filter(time_value >= max(time_value) - 14)
11+
12+
113
test_that("predict layer works alone", {
2-
jhu <- case_death_rate_subset %>%
3-
dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))
4-
r <- epi_recipe(jhu) %>%
5-
step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
6-
step_epi_ahead(death_rate, ahead = 7) %>%
7-
step_naomit(all_predictors()) %>%
8-
step_naomit(all_outcomes(), skip = TRUE)
9-
wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
10-
latest <- jhu %>%
11-
dplyr::filter(time_value >= max(time_value) - 14)
1214

1315
f <- frosting() %>% layer_predict()
14-
wf <- wf %>% add_frosting(f)
16+
wf1 <- wf %>% add_frosting(f)
1517

16-
expect_silent(p <- predict(wf, latest))
18+
expect_silent(p <- predict(wf1, latest))
1719
expect_equal(ncol(p), 3L)
1820
expect_s3_class(p, "epi_df")
1921
expect_equal(nrow(p), 108L)
2022
expect_named(p, c("geo_value", "time_value", ".pred"))
2123

2224
})
25+
26+
test_that("prediction with interval works", {
27+
28+
f <- frosting() %>% layer_predict(type = "pred_int")
29+
wf2 <- wf %>% add_frosting(f)
30+
31+
expect_silent(p <- predict(wf2, latest))
32+
expect_equal(ncol(p), 4L)
33+
expect_s3_class(p, "epi_df")
34+
expect_equal(nrow(p), 108L)
35+
expect_named(p, c("geo_value", "time_value", ".pred_lower", ".pred_upper"))
36+
37+
38+
39+
})
40+

0 commit comments

Comments
 (0)