diff --git a/NAMESPACE b/NAMESPACE index 4d1a9118..3405d737 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -56,6 +56,7 @@ S3method(print,compare.loo) S3method(print,compare.loo_ss) S3method(print,importance_sampling) S3method(print,importance_sampling_loo) +S3method(print,kfold) S3method(print,loo) S3method(print,pareto_k_table) S3method(print,pseudobma_bb_weights) diff --git a/R/print.R b/R/print.R index f05834d6..6a3f2139 100644 --- a/R/print.R +++ b/R/print.R @@ -105,6 +105,28 @@ print.importance_sampling <- function(x, digits = 1, plot_k = FALSE, ...) { invisible(x) } +#' @export +#' @rdname print.loo +print.kfold <- function(x, digits = 1, plot_k = FALSE, ...) { + print.loo(x, digits = digits, ...) + + if ("diagnostics" %in% names(x)) { + cat("------\n") + S <- dim(x)[1] + k_threshold <- ps_khat_threshold(S) + if (length(pareto_k_ids(x, threshold = k_threshold))) { + cat("\n") + } + print(pareto_k_table(x), digits = digits) + cat(.k_help()) + + if (plot_k) { + graphics::plot(x, ...) + } + } + return(invisible(x)) + } + # internal ---------------------------------------------------------------- #' Print dimensions of log-likelihood or log-weights matrix diff --git a/man/print.loo.Rd b/man/print.loo.Rd index 78a5bedd..42a37bbe 100644 --- a/man/print.loo.Rd +++ b/man/print.loo.Rd @@ -8,6 +8,7 @@ \alias{print.psis_loo_ap} \alias{print.psis} \alias{print.importance_sampling} +\alias{print.kfold} \title{Print methods} \usage{ \method{print}{loo}(x, digits = 1, ...) @@ -23,6 +24,8 @@ \method{print}{psis}(x, digits = 1, plot_k = FALSE, ...) \method{print}{importance_sampling}(x, digits = 1, plot_k = FALSE, ...) + +\method{print}{kfold}(x, digits = 1, plot_k = FALSE, ...) } \arguments{ \item{x}{An object returned by \code{\link[=loo]{loo()}}, \code{\link[=psis]{psis()}}, or \code{\link[=waic]{waic()}}.} diff --git a/tests/testthat/_snaps/print_plot.md b/tests/testthat/_snaps/print_plot.md index 17742c92..564d6e7e 100644 --- a/tests/testthat/_snaps/print_plot.md +++ b/tests/testthat/_snaps/print_plot.md @@ -2,3 +2,54 @@ WAoAAAACAAQFAAACAwAAAAAOAAAAAT+2J8YDcP5s +# print.loo supports kfold with pareto-k diagnostics - calibrated + + Code + print(kfold1) + Output + + Based on 10-fold cross-validation. + + Estimate SE + elpd_kfold -285.0 9.2 + p_kfold 2.5 0.6 + kfoldic 570.0 18.4 + ------ + + All Pareto k estimates are good (k < 0.7). + See help('pareto-k-diagnostic') for details. + +# print.loo supports kfold with pareto-k diagnostics - miscalibrated + + Code + print(kfold1) + Output + + Based on 10-fold cross-validation. + + Estimate SE + elpd_kfold -5556.6 701.0 + p_kfold 358.2 108.5 + kfoldic 11113.1 1401.9 + ------ + + Pareto k diagnostic values: + Count Pct. Min. ESS + (-Inf, 0.7] (good) 245 93.5% 24 + (0.7, 1] (bad) 8 3.1% + (1, Inf) (very bad) 9 3.4% + See help('pareto-k-diagnostic') for details. + +# print.loo supports kfold without pareto-k diagnostics + + Code + print(kfold1) + Output + + Based on 10-fold cross-validation. + + Estimate SE + elpd_kfold -5556.6 701.0 + p_kfold 358.2 108.5 + kfoldic 11113.1 1401.9 + diff --git a/tests/testthat/data-for-tests/kfold-calibrated.Rds b/tests/testthat/data-for-tests/kfold-calibrated.Rds new file mode 100644 index 00000000..e20d23e6 Binary files /dev/null and b/tests/testthat/data-for-tests/kfold-calibrated.Rds differ diff --git a/tests/testthat/data-for-tests/kfold-miscalibrated.Rds b/tests/testthat/data-for-tests/kfold-miscalibrated.Rds new file mode 100644 index 00000000..c0b371ba Binary files /dev/null and b/tests/testthat/data-for-tests/kfold-miscalibrated.Rds differ diff --git a/tests/testthat/data-for-tests/kfold-notes.md b/tests/testthat/data-for-tests/kfold-notes.md new file mode 100644 index 00000000..c0da2a97 --- /dev/null +++ b/tests/testthat/data-for-tests/kfold-notes.md @@ -0,0 +1,31 @@ +## Test data for testing print method of `kfold` object + +### Case 1: All pareto-k values are good + +```{r} +set.seed(123) +dat <- dplyr::tibble( + x = rnorm(200), + y = 2 + 1.5 * x + rnorm(200, sd = 1) +) + +fit <- brm(y ~ x, data = dat, seed = 42) +kfold1 <- kfold(fit) +saveRDS(kfold, "kfold-calibrated.Rds") +``` + +### Case 2: Some pareto-k values are problematic + +```{r} +data(roaches, package = "rstanarm") +roaches$sqrt_roach1 <- sqrt(roaches$roach1) + +fit_p <- brm(y ~ sqrt_roach1 + treatment + senior + offset(log(exposure2)), + data = roaches, + family = poisson, + prior = prior(normal(0,1), class = b), + refresh = 0) + +kfold2 <- kfold(fit_p) +saveRDS(kfold2, "kfold-miscalibrated.Rds") +``` \ No newline at end of file diff --git a/tests/testthat/test_print_plot.R b/tests/testthat/test_print_plot.R index f69ed6f0..a991c57f 100644 --- a/tests/testthat/test_print_plot.R +++ b/tests/testthat/test_print_plot.R @@ -185,3 +185,24 @@ test_that("mcse_loo returns NA when it should", { test_that("mcse_loo errors if not psis_loo object", { expect_error(mcse_loo(psis1), "psis_loo") }) + +# print.loo kfold objects -------------------------------------------------- + +test_that("print.loo supports kfold with pareto-k diagnostics - calibrated", { + kfold1 <- readRDS("data-for-tests/kfold-calibrated.Rds") + + expect_snapshot(print(kfold1)) +}) + +test_that("print.loo supports kfold with pareto-k diagnostics - miscalibrated", { + kfold1 <- readRDS("data-for-tests/kfold-miscalibrated.Rds") + + expect_snapshot(print(kfold1)) +}) + +test_that("print.loo supports kfold without pareto-k diagnostics", { + kfold1 <- readRDS("data-for-tests/kfold-miscalibrated.Rds") + kfold1$diagnostics <- NULL + + expect_snapshot(print(kfold1)) +}) \ No newline at end of file