Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ S3method(print,compare.loo)
S3method(print,compare.loo_ss)
S3method(print,importance_sampling)
S3method(print,importance_sampling_loo)
S3method(print,kfold)
S3method(print,loo)
S3method(print,pareto_k_table)
S3method(print,pseudobma_bb_weights)
Expand Down
22 changes: 22 additions & 0 deletions R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,28 @@ print.importance_sampling <- function(x, digits = 1, plot_k = FALSE, ...) {
invisible(x)
}

#' @export
#' @rdname print.loo
print.kfold <- function(x, digits = 1, plot_k = FALSE, ...) {
print.loo(x, digits = digits, ...)

if ("diagnostics" %in% names(x)) {
cat("------\n")
S <- dim(x)[1]
k_threshold <- ps_khat_threshold(S)
if (length(pareto_k_ids(x, threshold = k_threshold))) {
cat("\n")
}
print(pareto_k_table(x), digits = digits)
cat(.k_help())

if (plot_k) {
graphics::plot(x, ...)
}
}
return(invisible(x))
}

# internal ----------------------------------------------------------------

#' Print dimensions of log-likelihood or log-weights matrix
Expand Down
3 changes: 3 additions & 0 deletions man/print.loo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions tests/testthat/_snaps/print_plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,54 @@

WAoAAAACAAQFAAACAwAAAAAOAAAAAT+2J8YDcP5s

# print.loo supports kfold with pareto-k diagnostics - calibrated

Code
print(kfold1)
Output

Based on 10-fold cross-validation.

Estimate SE
elpd_kfold -285.0 9.2
p_kfold 2.5 0.6
kfoldic 570.0 18.4
------

All Pareto k estimates are good (k < 0.7).
See help('pareto-k-diagnostic') for details.

# print.loo supports kfold with pareto-k diagnostics - miscalibrated

Code
print(kfold1)
Output

Based on 10-fold cross-validation.

Estimate SE
elpd_kfold -5556.6 701.0
p_kfold 358.2 108.5
kfoldic 11113.1 1401.9
------

Pareto k diagnostic values:
Count Pct. Min. ESS
(-Inf, 0.7] (good) 245 93.5% 24
(0.7, 1] (bad) 8 3.1% <NA>
(1, Inf) (very bad) 9 3.4% <NA>
See help('pareto-k-diagnostic') for details.

# print.loo supports kfold without pareto-k diagnostics

Code
print(kfold1)
Output

Based on 10-fold cross-validation.

Estimate SE
elpd_kfold -5556.6 701.0
p_kfold 358.2 108.5
kfoldic 11113.1 1401.9

Binary file not shown.
Binary file not shown.
31 changes: 31 additions & 0 deletions tests/testthat/data-for-tests/kfold-notes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
## Test data for testing print method of `kfold` object

### Case 1: All pareto-k values are good

```{r}
set.seed(123)
dat <- dplyr::tibble(
x = rnorm(200),
y = 2 + 1.5 * x + rnorm(200, sd = 1)
)

fit <- brm(y ~ x, data = dat, seed = 42)
kfold1 <- kfold(fit)
saveRDS(kfold, "kfold-calibrated.Rds")
```

### Case 2: Some pareto-k values are problematic

```{r}
data(roaches, package = "rstanarm")
roaches$sqrt_roach1 <- sqrt(roaches$roach1)

fit_p <- brm(y ~ sqrt_roach1 + treatment + senior + offset(log(exposure2)),
data = roaches,
family = poisson,
prior = prior(normal(0,1), class = b),
refresh = 0)

kfold2 <- kfold(fit_p)
saveRDS(kfold2, "kfold-miscalibrated.Rds")
```
21 changes: 21 additions & 0 deletions tests/testthat/test_print_plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,24 @@ test_that("mcse_loo returns NA when it should", {
test_that("mcse_loo errors if not psis_loo object", {
expect_error(mcse_loo(psis1), "psis_loo")
})

# print.loo kfold objects --------------------------------------------------

test_that("print.loo supports kfold with pareto-k diagnostics - calibrated", {
kfold1 <- readRDS("data-for-tests/kfold-calibrated.Rds")

expect_snapshot(print(kfold1))
})

test_that("print.loo supports kfold with pareto-k diagnostics - miscalibrated", {
kfold1 <- readRDS("data-for-tests/kfold-miscalibrated.Rds")

expect_snapshot(print(kfold1))
})

test_that("print.loo supports kfold without pareto-k diagnostics", {
kfold1 <- readRDS("data-for-tests/kfold-miscalibrated.Rds")
kfold1$diagnostics <- NULL

expect_snapshot(print(kfold1))
})
Loading