Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: robyn_response()'s plot labels improvement #1212

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: Robyn
Type: Package
Title: Semi-Automated Marketing Mix Modeling (MMM) from Meta Marketing Science
Version: 3.12.0.9006
Version: 3.12.0.9007
Authors@R: c(
person("Gufeng", "Zhou", , "[email protected]", c("cre", "aut")),
person("Igor", "Skokan", , "[email protected]", c("aut")),
Expand Down
11 changes: 4 additions & 7 deletions R/R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -928,11 +928,10 @@ check_metric_dates <- function(date_range = NULL, all_dates, dayInterval = NULL,
} else if (is.Date(as.Date(date_range[1]))) {
## Using dates as date_range range
date_range_updated <- date_range <- as.Date(date_range, origin = "1970-01-01")
if (!all(date_range %in% all_dates)) {
date_range_loc <- range(sapply(date_range, FUN = function(x) which.min(abs(x - all_dates))))
date_range_loc <- seq(from = date_range_loc[1], to = date_range_loc[2], by = 1)
} else {
date_range_loc <- which(all_dates %in% date_range)
if (length(date_range) == 2) {
date_range_loc <- which(all_dates >= date_range[1] & all_dates <= date_range[2])
} else if (length(date_range) == 1) {
date_range_loc <- sapply(date_range, FUN = function(x) which.min(abs(x - all_dates)))
}
date_range_updated <- all_dates[date_range_loc]
} else {
Expand Down Expand Up @@ -964,8 +963,6 @@ check_metric_value <- function(metric_value, metric_name, all_values, metric_loc
# message(paste0("'metric_value'", metric_value, " splitting into ", get_n, " periods evenly"))
} else if (get_n == 1 & length(metric_value) == 1) {
metric_value_updated <- metric_value
} else {
stop("robyn_response metric_value & date_range must have same length\n")
}
}
all_values_updated <- all_values
Expand Down
96 changes: 58 additions & 38 deletions R/R/response.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ robyn_response <- function(InputCollect = NULL,
date_range = NULL,
dt_hyppar = NULL,
dt_coef = NULL,
plots = TRUE,
quiet = FALSE,
...) {
## Get input
Expand Down Expand Up @@ -177,7 +178,7 @@ robyn_response <- function(InputCollect = NULL,
metric_name_updated <- metric_type$metric_name_updated
all_dates <- dt_input[[InputCollect$date_var]]
all_values <- dt_mod[[metric_name_updated]]
ds_list <- check_metric_dates(date_range = date_range, all_dates[1:window_end_loc], dayInterval, quiet, ...)
ds_list <- check_metric_dates(date_range, all_dates[1:window_end_loc], dayInterval, quiet, ...)
val_list <- check_metric_value(metric_value, metric_name_updated, all_values, ds_list$metric_loc)
if (!is.null(metric_value) & is.null(date_range)) {
stop("Must specify date_range when using metric_value")
Expand Down Expand Up @@ -219,6 +220,8 @@ robyn_response <- function(InputCollect = NULL,
dt_point_sim <- data.frame(
input = hist_transform$sim_mean_spend + hist_transform$sim_mean_carryover,
output = hist_transform$sim_mean_response)
} else {
dt_point_sim <- NULL
}

## Simulated transformation
Expand All @@ -234,44 +237,60 @@ robyn_response <- function(InputCollect = NULL,
}

## Plot optimal response
p_res <- ggplot(dt_line, aes(x = .data$metric, y = .data$response)) +
geom_line(color = "steelblue") +
geom_point(
data = dt_point,
aes(x = .data$mean_input_total, y = .data$mean_response_total),
size = 3, color = "grey") +
labs(
title = paste(
"Saturation curve of", metric_type$metric_type,
"media:", metric_type$metric_name_updated
),
subtitle = sprintf(paste(
"Response: %s @ mean input %s",
"Response: %s @ mean input carryover %s",
"Response: %s @ mean input immediate %s",
sep = "\n"),
num_abbr(dt_point$mean_response_total),
num_abbr(dt_point$mean_input_total),
num_abbr(dt_point$mean_response_carryover),
num_abbr(dt_point$mean_input_carryover),
num_abbr(dt_point$mean_response_immediate),
num_abbr(dt_point$mean_input_immediate)
),
x = "Input", y = "Response",
caption = sprintf(
"Response period: %s%s%s",
head(date_range_updated, 1),
ifelse(length(date_range_updated) > 1, paste(" to", tail(date_range_updated, 1)), ""),
ifelse(length(date_range_updated) > 1, paste0(" [", length(date_range_updated), " periods]"), "")
)
) +
theme_lares(background = "white") +
scale_x_abbr() +
scale_y_abbr()
if (!is.null(metric_value) | !is.null(date_range)) {
p_res <- p_res +
geom_point(data = dt_point_sim, aes(x = .data$input, y = .data$output), size = 3, color = "blue")
if (isTRUE(plots)) {
# # Add c(0,0) as first point of the curve?
# dt_line <- bind_rows(
# data.frame(metric = 0, response = 0, channel = dt_line$channel[1]),
# dt_line)
p_res <- ggplot(dt_line, aes(x = .data$metric, y = .data$response)) +
geom_line(color = "steelblue") +
geom_point(
data = dt_point,
aes(x = .data$mean_input_total, y = .data$mean_response_total),
size = 3, color = "grey") +
labs(
title = paste(
"Saturation curve of", metric_type$metric_type,
"media:", metric_type$metric_name_updated
),
subtitle = sprintf(paste(
"%s response @ mean total input %s",
" %s response @ mean carryover input %s",
" %s response @ mean immediate input %s",
sep = "\n"),
num_abbr(dt_point$mean_response_total),
num_abbr(dt_point$mean_input_total),
num_abbr(dt_point$mean_response_carryover),
num_abbr(dt_point$mean_input_carryover),
num_abbr(dt_point$mean_response_immediate),
num_abbr(dt_point$mean_input_immediate)
),
x = sprintf("Input Metric per %s (%s)", InputCollect$intervalType, metric_type$metric_name_updated),
y = sprintf("Response (%s)", InputCollect$dep_var),
caption = sprintf(
"Response period: %s%s%s%s",
head(date_range_updated, 1),
ifelse(length(date_range_updated) > 1, paste(" to", tail(date_range_updated, 1)), ""),
ifelse(length(date_range_updated) > 1, paste0(" [", length(date_range_updated), " periods]"), ""),
paste("\nTotal Input Metric for period:", formatNum(ifelse(
!is.null(metric_value), metric_value, sum(hist_transform$input_total)),
abbr = TRUE))
)
) +
theme_lares(background = "white") +
scale_x_abbr() +
scale_y_abbr()
if (!is.null(dt_point_sim)) {
p_res <- p_res +
geom_point(data = dt_point_sim, aes(x = .data$input, y = .data$output), size = 3, color = "blue") +
labs(caption = paste0(
p_res$labels$caption,
sprintf("\n%s response @ input %s", num_abbr(dt_point_sim$output), num_abbr(dt_point_sim$input))))
}
} else {
p_res <- NULL
}

if (!is.null(metric_value)) {
sim_mean_spend <- hist_transform_sim$sim_mean_spend
sim_mean_carryover <- hist_transform_sim$sim_mean_carryover
Expand All @@ -282,6 +301,7 @@ robyn_response <- function(InputCollect = NULL,

ret <- list(
metric_name = metric_name_updated,
metric_value = val_list$metric_value_updated,
date = date_range_updated,
input_total = hist_transform$input_total,
input_carryover = hist_transform$input_carryover,
Expand Down
3 changes: 3 additions & 0 deletions R/man/robyn_response.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.