diff --git a/R/plot_mvgam_fc.R b/R/plot_mvgam_fc.R index 3b18ac59..a843a6dd 100644 --- a/R/plot_mvgam_fc.R +++ b/R/plot_mvgam_fc.R @@ -106,11 +106,6 @@ plot_mvgam_fc = function(object, series = 1, newdata, data_test, } } - # Use sensible ylimits for beta - if(object$family == 'beta'){ - ylim <- c(0, 1) - } - # Prediction indices for the particular series data_train <- object$obs_data ends <- seq(0, dim(mcmc_chains(object$model_output, 'ypred'))[2], @@ -252,8 +247,17 @@ plot_mvgam_fc = function(object, series = 1, newdata, data_test, dplyr::distinct() %>% dplyr::arrange(time) %>% dplyr::pull(y) - ylim <- c(min(cred, min(ytrain, na.rm = TRUE)), - max(cred, max(ytrain, na.rm = TRUE)) + 2) + + if(tolower(object$family) %in% c('beta', 'lognormal', 'gamma')){ + ylim <- c(min(cred, min(ytrain, na.rm = TRUE)), + max(cred, max(ytrain, na.rm = TRUE))) + ymin <- max(0, ylim[1]) + ymax <- min(1, ylim[2]) + ylim <- c(ymin, ymax) + } else { + ylim <- c(min(cred, min(ytrain, na.rm = TRUE)), + max(cred, max(ytrain, na.rm = TRUE))) + } } if(missing(ylab)){ @@ -530,7 +534,9 @@ plot.mvgam_forecast = function(x, series = 1, max(cred, max(ytrain, na.rm = TRUE)) * 1.1) if(object$family == 'beta'){ - ylim <- c(0, 1) + ymin <- max(0, ylim[1]) + ymax <- min(1, ylim[2]) + ylim <- c(ymin, ymax) } if(object$family %in% c('lognormal', 'Gamma')){ diff --git a/R/stan_utils.R b/R/stan_utils.R index 08d8e8d4..052b9e58 100644 --- a/R/stan_utils.R +++ b/R/stan_utils.R @@ -2588,9 +2588,26 @@ add_trend_predictors = function(trend_formula, trend_smooths_included <- FALSE # Add any multinormal smooth lines - if(any(grepl('multi_normal_prec', trend_model_file))){ + if(any(grepl('multi_normal_prec', trend_model_file)) | + any(grepl('// priors for smoothing parameters', trend_model_file))){ trend_smooths_included <- TRUE + # Replace any noncontiguous indices from trend model so names aren't + # conflicting with any possible indices in the observation model + if(any(grepl('idx', trend_model_file))){ + trend_model_file <- gsub('idx', 'trend_idx', trend_model_file) + idx_data <- trend_mvgam$model_data[grep('idx', names(trend_mvgam$model_data))] + names(idx_data) <- gsub('idx', 'trend_idx', names(idx_data)) + model_data <- append(model_data, idx_data) + + idx_lines <- grep('int trend_idx', trend_model_file) + model_file[min(grep('data {', model_file, fixed = TRUE))] <- + paste0('data {\n', + paste(trend_model_file[idx_lines], + collapse = '\n')) + model_file <- readLines(textConnection(model_file), n = -1) + } + if(any(grepl("int n_sp; // number of smoothing parameters", model_file, fixed = TRUE))){ model_file[grep("int n_sp; // number of smoothing parameters", @@ -2607,10 +2624,28 @@ add_trend_predictors = function(trend_formula, spline_coef_headers <- trend_model_file[grep('multi_normal_prec', trend_model_file) - 1] + if(any(grepl('normal(0, lambda', + trend_model_file, fixed = TRUE))){ + spline_coef_headers <- c(spline_coef_headers, + trend_model_file[grep('normal(0, lambda', + trend_model_file, fixed = TRUE)-1]) + } spline_coef_headers <- gsub('...', '_trend...', spline_coef_headers, fixed = TRUE) + spline_coef_lines <- trend_model_file[grepl('multi_normal_prec', trend_model_file)] + if(any(grepl('normal(0, lambda', + trend_model_file, fixed = TRUE))){ + lambda_normals <- (grep('normal(0, lambda', + trend_model_file, fixed = TRUE)) + for(i in 1:length(lambda_normals)){ + spline_coef_lines <- c(spline_coef_lines, + paste(trend_model_file[lambda_normals[i]], + collapse = '\n')) + } + } + spline_coef_lines <- gsub('_raw', '_raw_trend', spline_coef_lines) spline_coef_lines <- gsub('lambda', 'lambda_trend', spline_coef_lines) spline_coef_lines <- gsub('zero', 'zero_trend', spline_coef_lines) @@ -2681,23 +2716,28 @@ add_trend_predictors = function(trend_formula, } - S_lines <- trend_model_file[grep('mgcv smooth penalty matrix', - trend_model_file, fixed = TRUE)] - S_lines <- gsub('S', 'S_trend', S_lines, fixed = TRUE) - model_file[grep("int n_nonmissing; // number of nonmissing observations", - model_file, fixed = TRUE)] <- - paste0("int n_nonmissing; // number of nonmissing observations\n", - paste(S_lines, collapse = '\n')) + if(any(grepl('mgcv smooth penalty matrix', + trend_model_file, fixed = TRUE))){ + S_lines <- trend_model_file[grep('mgcv smooth penalty matrix', + trend_model_file, fixed = TRUE)] + S_lines <- gsub('S', 'S_trend', S_lines, fixed = TRUE) + model_file[grep("int n_nonmissing; // number of nonmissing observations", + model_file, fixed = TRUE)] <- + paste0("int n_nonmissing; // number of nonmissing observations\n", + paste(S_lines, collapse = '\n')) - S_mats <- trend_mvgam$model_data[paste0('S', 1:length(S_lines))] - names(S_mats) <- gsub('S', 'S_trend', names(S_mats)) - model_data <- append(model_data, S_mats) + S_mats <- trend_mvgam$model_data[paste0('S', 1:length(S_lines))] + names(S_mats) <- gsub('S', 'S_trend', names(S_mats)) + model_data <- append(model_data, S_mats) + } - model_file[grep("int num_basis_trend; // number of trend basis coefficients", - model_file, fixed = TRUE)] <- - paste0("int num_basis_trend; // number of trend basis coefficients\n", - "vector[num_basis_trend] zero_trend; // prior locations for trend basis coefficients") - model_data$zero_trend <- trend_mvgam$model_data$zero + if(!is.null(trend_mvgam$model_data$zero)){ + model_file[grep("int num_basis_trend; // number of trend basis coefficients", + model_file, fixed = TRUE)] <- + paste0("int num_basis_trend; // number of trend basis coefficients\n", + "vector[num_basis_trend] zero_trend; // prior locations for trend basis coefficients") + model_data$zero_trend <- trend_mvgam$model_data$zero + } if(any(grepl("vector[n_sp] rho;", model_file, fixed = TRUE))){ model_file[grep("vector[n_sp] rho;", model_file, fixed = TRUE)] <- diff --git a/src/mvgam.dll b/src/mvgam.dll index 4f52c29b..b52e44d4 100644 Binary files a/src/mvgam.dll and b/src/mvgam.dll differ diff --git a/tests/testthat/Rplots.pdf b/tests/testthat/Rplots.pdf index 9e6dab08..fdc1786a 100644 Binary files a/tests/testthat/Rplots.pdf and b/tests/testthat/Rplots.pdf differ diff --git a/tests/testthat/test-dynamic.R b/tests/testthat/test-dynamic.R index 1a0df406..6cbc5e47 100644 --- a/tests/testthat/test-dynamic.R +++ b/tests/testthat/test-dynamic.R @@ -48,3 +48,18 @@ test_that("rho argument cannot be larger than N - 1", { 'Argument "rho" in dynamic() cannot be larger than (max(time) - 1)', fixed = TRUE) }) + +test_that("dynamic works for trend_formulas", { + mod <- mvgam(y ~ dynamic(time, rho = 5), + trend_formula = ~ dynamic(time, rho = 15), + trend_model = 'RW', + data = beta_data$data_train, + family = betar(), + run_model = FALSE) + expect_true(inherits(mod, 'mvgam_prefit')) + + # trend_idx should be in the model file and in the model data + expect_true(any(grepl('trend_idx', mod$model_file))) + expect_true(!is.null(mod$model_data$trend_idx1)) +}) +