Skip to content

Commit 963c237

Browse files
committed
code refactor and addition of .ppc_rootogram_data
1 parent 4f6eae7 commit 963c237

File tree

2 files changed

+158
-116
lines changed

2 files changed

+158
-116
lines changed

R/ppc-discrete.R

Lines changed: 157 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -256,103 +256,23 @@ ppc_rootogram <- function(y,
256256
bound_distinct = TRUE) {
257257
check_ignored_arguments(...)
258258
style <- match.arg(style)
259-
y <- validate_y(y)
260-
yrep <- validate_predictions(yrep, length(y))
261-
if (!all_counts(y)) {
262-
abort("ppc_rootogram expects counts as inputs to 'y'.")
263-
}
264-
if (!all_counts(yrep)) {
265-
abort("ppc_rootogram expects counts as inputs to 'yrep'.")
266-
}
267-
268-
alpha <- (1 - prob) / 2
269-
probs <- c(alpha, 1 - alpha)
270-
ymax <- max(y, yrep)
271-
xpos <- 0L:ymax
272-
273-
# prepare a table for yrep
274-
tyrep <- as.list(rep(NA, nrow(yrep)))
275-
for (i in seq_along(tyrep)) {
276-
tyrep[[i]] <- table(yrep[i,])
277-
matches <- match(xpos, rownames(tyrep[[i]]))
278-
tyrep[[i]] <- as.numeric(tyrep[[i]][matches])
279-
}
280-
tyrep <- do.call(rbind, tyrep)
281-
tyrep[is.na(tyrep)] <- 0
282-
283-
#Discrete style
284-
pred_median <- apply(tyrep, 2, median)
285-
pred_quantile <- t(apply(tyrep, 2, quantile, probs = probs))
286-
colnames(pred_quantile) <- c("lower", "upper")
287-
288-
# prepare a table for y
289-
ty <- table(y)
290-
y_count <- as.numeric(ty[match(xpos, rownames(ty))])
291-
y_count[is.na(y_count)] <- 0
292-
293-
if (style == "discrete") {
294-
if (bound_distinct) {
295-
# If the observed count is within the bounds of the predicted quantiles,
296-
# use a different shape for the point
297-
obs_shape <- obs_shape <- ifelse(y_count >= pred_quantile[, "lower"] & y_count <= pred_quantile[, "upper"], "In", "Out")
298-
} else {
299-
obs_shape <- rep("y", length(y_count)) # all points are the same shape for observed
300-
}
301259

302-
data <- data.frame(
303-
xpos = xpos,
304-
obs = y_count,
305-
pred_median = pred_median,
306-
lower = pred_quantile[, "lower"],
307-
upper = pred_quantile[, "upper"],
308-
obs_shape = obs_shape
309-
)
310-
# Create the graph
311-
graph <- ggplot(data, aes(x = xpos)) +
312-
geom_pointrange(aes(y = pred_median, ymin = lower, ymax = upper, color = "y_rep"), fill = get_color("lh"), linewidth = size, size = size, fatten = 2, alpha = 1) +
313-
geom_point(aes(y = obs, shape = obs_shape), size = size * 1.5, color = get_color("d"), fill = get_color("d")) +
314-
scale_y_sqrt() +
315-
scale_fill_manual("", values = get_color("d"), guide="none") +
316-
scale_color_manual("", values = get_color("lh"), labels = yrep_label()) +
317-
labs(x = expression(italic(y)), y = "Count") +
318-
bayesplot_theme_get() +
319-
reduce_legend_spacing(0.25) +
320-
scale_shape_manual(values = c("In" = 22, "Out" = 23, "y" = 22), guide = "legend", labels = c("y" = expression(italic(y))))
321-
if (bound_distinct) {
322-
graph <- graph +
323-
guides(shape = guide_legend(expression(italic(y)~within~bounds)))
324-
} else {
325-
graph <- graph +
326-
guides(shape = guide_legend(" "))
327-
}
328-
return(graph)
329-
}
330-
331-
332-
#Standing, hanging, and suspended styles
333-
tyexp <- sqrt(colMeans(tyrep))
334-
tyquantile <- sqrt(t(apply(tyrep, 2, quantile, probs = probs)))
335-
colnames(tyquantile) <- c("tylower", "tyupper")
336-
337-
# prepare a table for y
338-
ty <- table(y)
339-
ty <- sqrt(as.numeric(ty[match(xpos, rownames(ty))]))
340-
if (style == "suspended") {
341-
ty <- tyexp - ty
342-
}
343-
ty[is.na(ty)] <- 0
344-
ypos <- ty / 2
345-
if (style == "hanging") {
346-
ypos <- tyexp - ypos
347-
}
260+
data <- .ppc_rootogram_data(
261+
y = y,
262+
yrep = yrep,
263+
style = style,
264+
prob = prob,
265+
bound_distinct = bound_distinct
266+
)
348267

349-
data <- data.frame(xpos, ypos, ty, tyexp, tyquantile)
350-
graph <- ggplot(data) +
351-
aes(
352-
ymin = .data$tylower,
353-
ymax = .data$tyupper,
354-
height = .data$ty
355-
) +
268+
# Building geoms for y and y_rep
269+
geom_y <- if (style == "discrete") {
270+
geom_point(
271+
aes(y = .data$obs, shape = .data$obs_shape),
272+
size = size * 1.5,
273+
color = get_color("d"),
274+
fill = get_color("d"))
275+
} else {
356276
geom_tile(
357277
aes(
358278
x = .data$xpos,
@@ -362,34 +282,69 @@ ppc_rootogram <- function(y,
362282
color = get_color("lh"),
363283
linewidth = 0.25,
364284
width = 1
365-
) +
366-
bayesplot_theme_get()
367-
368-
if (style != "standing") {
369-
graph <- graph + hline_0(size = 0.4)
285+
)
370286
}
371287

372-
graph <- graph +
288+
geom_yrep <- if (style == "discrete") {
289+
geom_pointrange(
290+
aes(y = .data$pred_median, ymin = .data$lower, ymax = .data$upper, color = "y_rep"),
291+
fill = get_color("lh"),
292+
linewidth = size,
293+
size = size,
294+
fatten = 2,
295+
alpha = 1
296+
)
297+
} else {
373298
geom_smooth(
374-
aes(
375-
x = .data$xpos,
376-
y = .data$tyexp,
377-
color = "Expected"
378-
),
299+
aes(x = .data$xpos, y = .data$tyexp, color = "Expected"),
379300
fill = get_color("d"),
380301
linewidth = size,
381302
stat = "identity"
382-
) +
383-
scale_fill_manual("", values = get_color("l")) +
384-
scale_color_manual("", values = get_color("dh")) +
385-
labs(x = expression(italic(y)),
386-
y = expression(sqrt(Count)))
387-
388-
if (style == "standing") {
389-
graph <- graph + dont_expand_y_axis()
303+
)
390304
}
391305

392-
graph + reduce_legend_spacing(0.25)
306+
# Creating the graph
307+
graph <- ggplot(data)
308+
309+
if (style == "discrete") {
310+
graph <- graph +
311+
geom_yrep +
312+
geom_y +
313+
aes(x = xpos) +
314+
scale_y_sqrt() +
315+
scale_fill_manual("", values = get_color("d"), guide = "none") +
316+
scale_color_manual("", values = get_color("lh"), labels = yrep_label()) +
317+
labs(x = expression(italic(y)), y = "Count") +
318+
bayesplot_theme_get() +
319+
reduce_legend_spacing(0.25) +
320+
scale_shape_manual(values = c("In" = 22, "Out" = 23, "y" = 22), guide = "legend", labels = c("y" = expression(italic(y))))
321+
if (bound_distinct) {
322+
graph <- graph + guides(shape = guide_legend(expression(italic(y)~within~bounds)))
323+
} else {
324+
graph <- graph + guides(shape = guide_legend(" "))
325+
}
326+
} else {
327+
graph <- graph +
328+
geom_y +
329+
geom_yrep +
330+
aes(
331+
ymin = .data$tylower,
332+
ymax = .data$tyupper,
333+
height = .data$ty
334+
) +
335+
scale_fill_manual("", values = get_color("l")) +
336+
scale_color_manual("", values = get_color("dh")) +
337+
labs(x = expression(italic(y)), y = expression(sqrt(Count))) +
338+
bayesplot_theme_get() +
339+
reduce_legend_spacing(0.25)
340+
if (style == "standing") {
341+
graph <- graph + dont_expand_y_axis()
342+
} else {
343+
graph <- graph + hline_0(size = 0.4)
344+
}
345+
}
346+
347+
return(graph)
393348
}
394349

395350

@@ -504,3 +459,90 @@ bars_group_facets <- function(facet_args, scales_default = "fixed") {
504459
fixed_y <- function(facet_args) {
505460
!isTRUE(facet_args[["scales"]] %in% c("free", "free_y"))
506461
}
462+
463+
#' Internal function for `ppc_rootogram()`
464+
#' @param y,yrep User's `y` and `yrep` arguments.
465+
#' @param style,prob,bound_distinct User's `style`, `prob`, and
466+
#' (if applicable) `bound_distinct` arguments.
467+
#' @noRd
468+
.ppc_rootogram_data <- function(y,
469+
yrep,
470+
style = c("standing", "hanging", "suspended", "discrete"),
471+
prob = 0.9,
472+
bound_distinct) {
473+
474+
y <- validate_y(y)
475+
yrep <- validate_predictions(yrep, length(y))
476+
if (!all_counts(y)) {
477+
abort("ppc_rootogram expects counts as inputs to 'y'.")
478+
}
479+
if (!all_counts(yrep)) {
480+
abort("ppc_rootogram expects counts as inputs to 'yrep'.")
481+
}
482+
483+
alpha <- (1 - prob) / 2
484+
probs <- c(alpha, 1 - alpha)
485+
ymax <- max(y, yrep)
486+
xpos <- 0L:ymax
487+
488+
# prepare a table for yrep
489+
tyrep <- as.list(rep(NA, nrow(yrep)))
490+
for (i in seq_along(tyrep)) {
491+
tyrep[[i]] <- table(yrep[i,])
492+
matches <- match(xpos, rownames(tyrep[[i]]))
493+
tyrep[[i]] <- as.numeric(tyrep[[i]][matches])
494+
}
495+
tyrep <- do.call(rbind, tyrep)
496+
tyrep[is.na(tyrep)] <- 0
497+
498+
# discrete style
499+
if (style == "discrete"){
500+
pred_median <- apply(tyrep, 2, median)
501+
pred_quantile <- t(apply(tyrep, 2, quantile, probs = probs))
502+
colnames(pred_quantile) <- c("lower", "upper")
503+
504+
# prepare a table for y
505+
ty <- table(y)
506+
y_count <- as.numeric(ty[match(xpos, rownames(ty))])
507+
y_count[is.na(y_count)] <- 0
508+
509+
if (bound_distinct) {
510+
# If the observed count is within the bounds of the predicted quantiles,
511+
# use a different shape for the point
512+
obs_shape <- obs_shape <- ifelse(y_count >= pred_quantile[, "lower"] & y_count <= pred_quantile[, "upper"], "In", "Out")
513+
} else {
514+
obs_shape <- rep("y", length(y_count)) # all points are the same shape for observed
515+
}
516+
517+
data <- data.frame(
518+
xpos = xpos,
519+
obs = y_count,
520+
pred_median = pred_median,
521+
lower = pred_quantile[, "lower"],
522+
upper = pred_quantile[, "upper"],
523+
obs_shape = obs_shape
524+
)
525+
}
526+
# standing, hanging, suspended styles
527+
else {
528+
tyexp <- sqrt(colMeans(tyrep))
529+
tyquantile <- sqrt(t(apply(tyrep, 2, quantile, probs = probs)))
530+
colnames(tyquantile) <- c("tylower", "tyupper")
531+
532+
# prepare a table for y
533+
ty <- table(y)
534+
ty <- sqrt(as.numeric(ty[match(xpos, rownames(ty))]))
535+
if (style == "suspended") {
536+
ty <- tyexp - ty
537+
}
538+
ty[is.na(ty)] <- 0
539+
ypos <- ty / 2
540+
if (style == "hanging") {
541+
ypos <- tyexp - ypos
542+
}
543+
544+
data <- data.frame(xpos, ypos, ty, tyexp, tyquantile)
545+
}
546+
547+
return(data)
548+
}

tests/testthat/_snaps/ppc-discrete/ppc-rootogram-style-hanging-prob-size.svg

Lines changed: 1 addition & 1 deletion
Loading

0 commit comments

Comments
 (0)