Skip to content

Commit a20e672

Browse files
committed
added marginal plots
1 parent a746f45 commit a20e672

File tree

4 files changed

+176
-6
lines changed

4 files changed

+176
-6
lines changed

R/dplot3_calibration.R

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#' Draw calibration plot
66
#'
7+
#' @inheritParams dplot3_xy
78
#' @param true.labels Factor or list of factors with true class labels
89
#' @param est.prob Numeric vector or list of numeric vectors with predicted probabilities
910
#' @param bin.method Character: "quantile" or "equidistant": Method to bin the estimated
@@ -14,6 +15,7 @@
1415
#' @param subtitle Character: Subtitle, placed bottom right of plot
1516
#' @param xlab Character: x-axis label
1617
#' @param ylab Character: y-axis label
18+
#' @param show.marginal.x Logical: Add marginal plot of distribution of estimated probabilities
1719
#' @param mode Character: Plot mode
1820
#' @param filename Character: Path to save output.
1921
#' @param ... Additional arguments passed to [dplot3_xy]
@@ -52,10 +54,17 @@ dplot3_calibration <- function(true.labels, est.prob,
5254
subtitle = NULL,
5355
xlab = "Mean estimated probability",
5456
ylab = "Empirical risk",
57+
show.marginal.x = TRUE,
58+
marginal.x.y = -.02,
59+
marginal.col = NULL,
60+
marginal.size = 10,
61+
show.bins = TRUE,
5562
# conf_level = .95,
5663
mode = "markers+lines",
5764
print.brier = TRUE,
65+
theme = rtTheme,
5866
filename = NULL, ...) {
67+
# Arguments ----
5968
bin.method <- match.arg(bin.method)
6069
if (is.null(pos.class)) {
6170
pos.class <- rtenv$binclasspos
@@ -69,6 +78,10 @@ dplot3_calibration <- function(true.labels, est.prob,
6978
# Ensure same number of inputs
7079
stopifnot(length(true.labels) == length(est.prob))
7180

81+
# Theme ----
82+
if (is.character(theme)) {
83+
theme <- do.call(paste0("theme_", theme), list())
84+
}
7285
pos_class <- lapply(true.labels, \(x) {
7386
levels(x)[pos.class]
7487
})
@@ -136,22 +149,52 @@ dplot3_calibration <- function(true.labels, est.prob,
136149
)
137150
}
138151
# if (is.null(subtitle) && !is.na(subtitle)) .subtitle <- paste0(subtitle, "\n", .subtitle)
139-
dplot3_xy(
152+
plt <- dplot3_xy(
140153
x = mean_bin_prob,
141154
y = window_empirical_risk,
142155
main = main,
143156
# subtitle = paste("<i>", .subtitle, "</i>"),
144157
subtitle = subtitle,
145158
subtitle.x = 1,
146-
subtitle.y = .01,
159+
subtitle.y = 0,
160+
subtitle.yref = "y",
147161
subtitle.xanchor = "right",
148162
subtitle.yanchor = "bottom",
149163
xlab = xlab,
150164
ylab = ylab,
151-
axes.square = TRUE, diagonal = TRUE,
165+
show.marginal.x = show.marginal.x,
166+
marginal.x = est.prob,
167+
marginal.x.y = marginal.x.y,
168+
marginal.size = marginal.size,
169+
axes.square = TRUE,
170+
diagonal = TRUE,
152171
xlim = c(0, 1), ylim = c(0, 1),
153172
mode = mode,
173+
theme = theme,
154174
filename = filename, ...
155175
)
156176

177+
# Add marginal.x ----
178+
# Using estimated probabilities
179+
# if (marginal.x) {
180+
# if (is.null(marginal.col)) marginal.col <- plotly::toRGB(theme$fg, alpha = .5)
181+
# for (i in seq_along(mean_bin_prob)) {
182+
# plt <- plotly::add_trace(
183+
# plt,
184+
# x = est.prob[[i]],
185+
# y = rep(-.02, length(est.prob[[i]])),
186+
# type = "scatter",
187+
# mode = "markers",
188+
# marker = list(
189+
# color = marginal.col,
190+
# size = marginal.size,
191+
# symbol = "line-ns-open"
192+
# ),
193+
# showlegend = FALSE,
194+
# hoverinfo = "x"
195+
# )
196+
# }
197+
# } # /marginal.x
198+
199+
plt
157200
} # rtemis::dplot3_calibration

R/dplot3_xy.R

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
#' @param main.yanchor Character: "top", "middle", "bottom"
4646
#' @param subtitle.x Numeric: X position of subtitle relative to paper
4747
#' @param subtitle.y Numeric: Y position of subtitle relative to paper
48+
#' @param subtitle.xref Character: "paper", "x", "y"
49+
#' @param subtitle.yref Character: "paper", "x", "y"
4850
#' @param subtitle.xanchor Character: "left", "center", "right"
4951
#' @param subtitle.yanchor Character: "top", "middle", "bottom"
5052
#' @param scrollZoom Logical: If TRUE, enable scroll zoom
@@ -53,6 +55,16 @@
5355
#' @param symbol Character: Marker symbol.
5456
#' @param scatter.type Character: "scatter", "scattergl", "scatter3d", "scatterternary",
5557
#' "scatterpolar", "scattermapbox",
58+
#' @param show.marginal.x Logical: If TRUE, add marginal distribution line markers on x-axis
59+
#' @param show.marginal.y Logical: If TRUE, add marginal distribution line markers on y-axis
60+
#' @param marginal.x Numeric: Data whose distribution will be shown on x-axis. Only
61+
#' specify if different from `x`
62+
#' @param marginal.y Numeric: Data whose distribution will be shown on y-axis. Only
63+
#' specify if different from `y`
64+
#' @param marginal.x.y Numeric: Y position of marginal markers on x-axis
65+
#' @param marginal.col Color for marginal markers
66+
#' @param marginal.alpha Numeric: Alpha for marginal markers
67+
#' @param marginal.size Numeric: Size of marginal markers
5668
#'
5769
#' @author E.D. Gennatas
5870
#' @export
@@ -98,6 +110,16 @@ dplot3_xy <- function(x, y = NULL,
98110
se.col = NULL,
99111
se.alpha = .4,
100112
scatter.type = "scatter",
113+
# Marginal plots
114+
show.marginal.x = FALSE,
115+
show.marginal.y = FALSE,
116+
marginal.x = x,
117+
marginal.y = y,
118+
marginal.x.y = NULL,
119+
marginal.y.x = NULL,
120+
marginal.col = NULL,
121+
marginal.alpha = .333,
122+
marginal.size = 5,
101123
legend = NULL,
102124
legend.xy = c(0, .98),
103125
legend.xanchor = "left",
@@ -120,6 +142,8 @@ dplot3_xy <- function(x, y = NULL,
120142
main.yanchor = "bottom",
121143
subtitle.x = 0.02,
122144
subtitle.y = 0.99,
145+
subtitle.xref = "paper",
146+
subtitle.yref = "paper",
123147
subtitle.xanchor = "left",
124148
subtitle.yanchor = "top",
125149
automargin.x = TRUE,
@@ -267,6 +291,10 @@ dplot3_xy <- function(x, y = NULL,
267291
}
268292
}
269293

294+
# Marginal data ----
295+
if (show.marginal.x && is.null(marginal.x)) marginal.x <- x
296+
if (show.marginal.y && is.null(marginal.y)) marginal.y <- y
297+
270298
# Reorder ----
271299
if (order.on.x) {
272300
index <- lapply(x, order)
@@ -496,8 +524,57 @@ dplot3_xy <- function(x, y = NULL,
496524
legendgroup = .names[i],
497525
showlegend = legend
498526
)
527+
# Marginal plots ----
528+
# Add marginal plots by plotting short vertical markers on the x and y axes
529+
if (show.marginal.x) {
530+
if (is.null(marginal.col)) {
531+
marginal.col <- plotly::toRGB(marker.col, alpha = marginal.alpha)
532+
}
533+
if (is.null(marginal.x.y)) marginal.x.y <- ylim[1]
534+
for (i in seq_len(n.groups)) {
535+
plt <- plotly::add_trace(plt,
536+
x = marginal.x[[i]],
537+
y = rep(marginal.x.y, length(marginal.x[[i]])),
538+
type = "scatter",
539+
mode = "markers",
540+
marker = list(
541+
color = marginal.col[[i]],
542+
size = marginal.size,
543+
symbol = "line-ns-open"
544+
),
545+
showlegend = FALSE,
546+
hoverinfo = "x"
547+
# legendgroup = .names[i],
548+
# inherit = FALSE
549+
)
550+
}
551+
} # /show.marginal.x
552+
553+
if (show.marginal.y) {
554+
if (is.null(marginal.col)) {
555+
marginal.col <- plotly::toRGB(marker.col, alpha = marginal.alpha)
556+
}
557+
if (is.null(marginal.y.x)) marginal.y.x <- xlim[1]
558+
for (i in seq_len(n.groups)) {
559+
plt <- plotly::add_trace(plt,
560+
x = rep(marginal.y.x, length(marginal.y[[i]])),
561+
y = marginal.y[[i]],
562+
type = "scatter",
563+
mode = "markers",
564+
marker = list(
565+
color = marginal.col[[i]],
566+
size = marginal.size,
567+
symbol = "line-ew-open"
568+
),
569+
showlegend = FALSE,
570+
hoverinfo = "y"
571+
# legendgroup = .names[i]
572+
)
573+
}
574+
} # /show.marginal.y
575+
576+
## { SE band } ----
499577
if (se.fit) {
500-
## { SE band } ----
501578
plt <- plotly::add_trace(plt,
502579
x = x[[i]],
503580
y = fitted[[i]] + se.times * se[[i]],
@@ -678,8 +755,8 @@ dplot3_xy <- function(x, y = NULL,
678755
plt <- plt |> plotly::add_annotations(
679756
x = subtitle.x,
680757
y = subtitle.y,
681-
xref = "paper",
682-
yref = "paper",
758+
xref = subtitle.xref,
759+
yref = subtitle.yref,
683760
xanchor = subtitle.xanchor,
684761
yanchor = subtitle.yanchor,
685762
text = subtitle,

man/dplot3_calibration.Rd

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/dplot3_xy.Rd

Lines changed: 33 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)