From 8e7d742a9ba44586121593f676de560c0352891c Mon Sep 17 00:00:00 2001 From: Koen Derks Date: Thu, 24 Oct 2024 18:24:13 +0200 Subject: [PATCH] Store all variables in saved model --- R/commonMachineLearningClassification.R | 4 +++- R/commonMachineLearningRegression.R | 6 ++++-- R/mlPrediction.R | 14 +++++++------- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/R/commonMachineLearningClassification.R b/R/commonMachineLearningClassification.R index 6345d19d..47a2d391 100644 --- a/R/commonMachineLearningClassification.R +++ b/R/commonMachineLearningClassification.R @@ -342,7 +342,9 @@ return() } model <- classificationResult[["model"]] - model[["jaspVars"]] <- decodeColNames(options[["predictors"]]) + model[["jaspVars"]] <- list() + model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]])) + model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]]) model[["jaspVersion"]] <- .baseCitation model[["explainer"]] <- classificationResult[["explainer"]] model <- .decodeJaspMLobject(model) diff --git a/R/commonMachineLearningRegression.R b/R/commonMachineLearningRegression.R index 7180bc26..34f9f6dc 100644 --- a/R/commonMachineLearningRegression.R +++ b/R/commonMachineLearningRegression.R @@ -454,7 +454,9 @@ return() } model <- regressionResult[["model"]] - model[["jaspVars"]] <- decodeColNames(options[["predictors"]]) + model[["jaspVars"]] <- list() + model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]])) + model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]]) model[["jaspVersion"]] <- .baseCitation model[["explainer"]] <- regressionResult[["explainer"]] model <- .decodeJaspMLobject(model) @@ -697,7 +699,7 @@ } else { purpose <- "classification" } - predictors <- options[["predictors"]][which(decodeColNames(options[["predictors"]]) %in% model[["jaspVars"]])] + predictors <- options[["predictors"]][which(decodeColNames(options[["predictors"]]) %in% model[["jaspVars"]][["decoded"]]$predictors)] } else { predictors <- options[["predictors"]] } diff --git a/R/mlPrediction.R b/R/mlPrediction.R index 65b0fecd..13dcee07 100644 --- a/R/mlPrediction.R +++ b/R/mlPrediction.R @@ -119,21 +119,21 @@ is.jaspMachineLearning <- function(x) { } .mlPredictionGetPredictions.nn <- function(model, dataset) { if (inherits(model, "jaspClassification")) { - as.character(levels(factor(model[["data"]][, 1]))[max.col(neuralnet:::predict.nn(model, newdata = dataset))]) + as.character(levels(factor(model[["data"]][[model[["jaspVars"]][["encoded"]]$target]]))[max.col(neuralnet:::predict.nn(model, newdata = dataset))]) } else if (inherits(model, "jaspRegression")) { as.numeric(neuralnet:::predict.nn(model, newdata = dataset)) } } .mlPredictionGetPredictions.rpart <- function(model, dataset) { if (inherits(model, "jaspClassification")) { - as.character(levels(factor(model[["data"]][, 1]))[max.col(rpart:::predict.rpart(model, newdata = dataset))]) + as.character(levels(factor(model[["data"]][[model[["jaspVars"]][["encoded"]]$target]]))[max.col(rpart:::predict.rpart(model, newdata = dataset))]) } else if (inherits(model, "jaspRegression")) { as.numeric(rpart:::predict.rpart(model, newdata = dataset)) } } .mlPredictionGetPredictions.svm <- function(model, dataset) { if (inherits(model, "jaspClassification")) { - as.character(levels(factor(model[["data"]][, 1]))[e1071:::predict.svm(model, newdata = dataset)]) + as.character(levels(factor(model[["data"]][[model[["jaspVars"]][["encoded"]]$target]]))[e1071:::predict.svm(model, newdata = dataset)]) } else if (inherits(model, "jaspRegression")) { as.numeric(e1071:::predict.svm(model, newdata = dataset)) } @@ -142,7 +142,7 @@ is.jaspMachineLearning <- function(x) { as.character(e1071:::predict.naiveBayes(model, newdata = dataset, type = "class")) } .mlPredictionGetPredictions.glm <- function(model, dataset) { - as.character(levels(as.factor(model$model[, 1]))[round(predict(model, newdata = dataset, type = "response"), 0) + 1]) + as.character(levels(as.factor(model$model[[model[["jaspVars"]][["encoded"]]$target]]))[round(predict(model, newdata = dataset, type = "response"), 0) + 1]) } .mlPredictionGetPredictions.vglm <- function(model, dataset) { model[["original"]]@terms$terms <- model[["terms"]] @@ -293,7 +293,7 @@ is.jaspMachineLearning <- function(x) { # also define methods for other objects .mlPredictionReady <- function(model, dataset, options) { if (!is.null(model)) { - modelVars <- model[["jaspVars"]] + modelVars <- model[["jaspVars"]][["decoded"]]$predictors presentVars <- decodeColNames(colnames(dataset)) ready <- all(modelVars %in% presentVars) } else { @@ -344,7 +344,7 @@ is.jaspMachineLearning <- function(x) { if (is.null(model)) { return() } - modelVars <- model[["jaspVars"]] + modelVars <- model[["jaspVars"]][["decoded"]]$predictors presentVars <- decodeColNames(colnames(dataset)) if (!all(modelVars %in% presentVars)) { missingVars <- modelVars[which(!(modelVars %in% presentVars))] @@ -422,7 +422,7 @@ is.jaspMachineLearning <- function(x) { selection <- predictions[indexes] cols <- list(row = indexes, pred = selection) if (options[["predictionsTableFeatures"]]) { - for (i in encodeColNames(model[["jaspVars"]])) { + for (i in model[["jaspVars"]][["encoded"]]$predictors) { if (.columnIsNominal(i)) { table$addColumnInfo(name = i, title = i, type = "string") var <- levels(dataset[[i]])[dataset[[i]]]