diff --git a/DESCRIPTION b/DESCRIPTION index a569f8fc..891b7249 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -41,7 +41,7 @@ Depends: Imports: checkmate, data.table, - jsonlite (>= 1.2.0), + jsonlite (>= 1.8.7), posterior (>= 1.5.0), processx (>= 3.5.0), R6 (>= 2.4.0), @@ -55,6 +55,7 @@ Suggests: loo (>= 2.0.0), qs2, rmarkdown, - testthat (>= 2.1.0), + testthat (>= 3.3.0), Rcpp VignetteBuilder: knitr +Config/testthat/edition: 3 diff --git a/tests/testthat/answers/json-boolean.json b/tests/testthat/_snaps/json/json-boolean.json similarity index 100% rename from tests/testthat/answers/json-boolean.json rename to tests/testthat/_snaps/json/json-boolean.json diff --git a/tests/testthat/answers/json-df-matrix.json b/tests/testthat/_snaps/json/json-df-matrix.json similarity index 100% rename from tests/testthat/answers/json-df-matrix.json rename to tests/testthat/_snaps/json/json-df-matrix.json diff --git a/tests/testthat/answers/json-factor.json b/tests/testthat/_snaps/json/json-factor.json similarity index 100% rename from tests/testthat/answers/json-factor.json rename to tests/testthat/_snaps/json/json-factor.json diff --git a/tests/testthat/answers/json-integer.json b/tests/testthat/_snaps/json/json-integer.json similarity index 100% rename from tests/testthat/answers/json-integer.json rename to tests/testthat/_snaps/json/json-integer.json diff --git a/tests/testthat/answers/json-matrix-lists.json b/tests/testthat/_snaps/json/json-matrix-lists.json similarity index 100% rename from tests/testthat/answers/json-matrix-lists.json rename to tests/testthat/_snaps/json/json-matrix-lists.json diff --git a/tests/testthat/answers/json-table-array.json b/tests/testthat/_snaps/json/json-table-array.json similarity index 100% rename from tests/testthat/answers/json-table-array.json rename to tests/testthat/_snaps/json/json-table-array.json diff --git a/tests/testthat/answers/json-table-matrix.json b/tests/testthat/_snaps/json/json-table-matrix.json similarity index 100% rename from tests/testthat/answers/json-table-matrix.json rename to tests/testthat/_snaps/json/json-table-matrix.json diff --git a/tests/testthat/answers/json-table-vector.json b/tests/testthat/_snaps/json/json-table-vector.json similarity index 100% rename from tests/testthat/answers/json-table-vector.json rename to tests/testthat/_snaps/json/json-table-vector.json diff --git a/tests/testthat/answers/json-unboxing.json b/tests/testthat/_snaps/json/json-unboxing.json similarity index 100% rename from tests/testthat/answers/json-unboxing.json rename to tests/testthat/_snaps/json/json-unboxing.json diff --git a/tests/testthat/answers/json-vector-lists.json b/tests/testthat/_snaps/json/json-vector-lists.json similarity index 100% rename from tests/testthat/answers/json-vector-lists.json rename to tests/testthat/_snaps/json/json-vector-lists.json diff --git a/tests/testthat/_snaps/model-code-print.md b/tests/testthat/_snaps/model-code-print.md new file mode 100644 index 00000000..b0ba98fb --- /dev/null +++ b/tests/testthat/_snaps/model-code-print.md @@ -0,0 +1,21 @@ +# code() and print() methods work + + data { + int N; + array[N] int y; + } + parameters { + real theta; + } + model { + theta ~ beta(1, 1); // uniform prior on interval 0,1 + y ~ bernoulli(theta); + } + +--- + + c("data {", " int N;", " array[N] int y;", + "}", "parameters {", " real theta;", "}", + "model {", " theta ~ beta(1, 1); // uniform prior on interval 0,1", + " y ~ bernoulli(theta);", "}") + diff --git a/tests/testthat/answers/model-code-output.rds b/tests/testthat/answers/model-code-output.rds deleted file mode 100644 index 540d5b54..00000000 Binary files a/tests/testthat/answers/model-code-output.rds and /dev/null differ diff --git a/tests/testthat/answers/model-print-output.stan b/tests/testthat/answers/model-print-output.stan deleted file mode 100644 index 3b6099fc..00000000 --- a/tests/testthat/answers/model-print-output.stan +++ /dev/null @@ -1,11 +0,0 @@ -data { - int N; - array[N] int y; -} -parameters { - real theta; -} -model { - theta ~ beta(1, 1); // uniform prior on interval 0,1 - y ~ bernoulli(theta); -} diff --git a/tests/testthat/helper-custom-expectations.R b/tests/testthat/helper-custom-expectations.R index 47a6e38b..d0128753 100644 --- a/tests/testthat/helper-custom-expectations.R +++ b/tests/testthat/helper-custom-expectations.R @@ -11,7 +11,11 @@ expect_compilation <- function(mod, ...) { } if(!is.null(before_mtime)) { after_mtime <- file.mtime(mod$exe_file()) - expect(before_mtime != after_mtime, sprintf("Exe file '%s' has NOT changed, despite expecting (re)compilation", mod$exe_file())) + expect_gt( + after_mtime, + before_mtime, + sprintf("Exe file '%s' has NOT changed, despite expecting (re)compilation", mod$exe_file()) + ) } invisible(mod) } @@ -26,7 +30,11 @@ expect_call_compilation <- function(constructor_call) { fail(sprint("Model executable '%s' does not exist after compilation.", mod$exe_file())) } after_mtime <- file.mtime(mod$exe_file()) - expect(before_time <= after_mtime, sprintf("Exe file '%s' has old timestamp, despite expecting (re)compilation", mod$exe_file())) + expect_gt( + after_mtime, + before_time, + sprintf("Exe file '%s' has old timestamp, despite expecting (re)compilation", mod$exe_file()) + ) invisible(mod) } @@ -40,7 +48,7 @@ expect_no_recompilation <- function(mod, ...) { before_mtime <- file.mtime(mod$exe_file()) expect_interactive_message(mod$compile(...), "Model executable is up to date!") after_mtime <- file.mtime(mod$exe_file()) - expect(before_mtime == after_mtime, sprintf("Model executable '%s' has changed, despite expecting no recompilation", mod$exe_file())) + expect_true(before_mtime == after_mtime, sprintf("Model executable '%s' has changed, despite expecting no recompilation", mod$exe_file())) invisible(mod) } @@ -92,8 +100,19 @@ expect_gq_output <- function(object, num_chains = NULL) { } expect_interactive_message <- function(object, regexp = NULL) { - rlang::with_interactive(value = TRUE, - expect_message(object = object, regexp = regexp)) + object <- substitute(object) + env <- parent.frame() + value <- NULL + rlang::with_interactive(value = TRUE, { + expect_message( + object = { + value <- rlang::eval_bare(object, env) + value + }, + regexp = regexp + ) + }) + invisible(value) } expect_noninteractive_silent <- function(object) { diff --git a/tests/testthat/helper-mock-cli.R b/tests/testthat/helper-mock-cli.R index 60a9e52d..799e8d1d 100644 --- a/tests/testthat/helper-mock-cli.R +++ b/tests/testthat/helper-mock-cli.R @@ -1,8 +1,9 @@ real_wcr <- wsl_compatible_run with_mocked_cli <- function(code, compile_ret, info_ret) { - with_mocked_bindings( - code, + code <- substitute(code) + caller <- parent.frame() + local_mocked_bindings( wsl_compatible_run = function(command, args, ...) { if ( !is.null(command) @@ -17,8 +18,11 @@ with_mocked_cli <- function(code, compile_ret, info_ret) { } else { real_wcr(command = command, args = args, ...) } - } + }, + .package = "cmdstanr", + .env = caller ) + rlang::eval_bare(code, env = caller) } ######## Mock Compile Expectations ####### diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R new file mode 100644 index 00000000..fb80e9e7 --- /dev/null +++ b/tests/testthat/setup.R @@ -0,0 +1,17 @@ +cleanup_stan_artifacts <- function() { + all_files_in_stan <- list.files( + test_path("resources", "stan"), + full.names = TRUE, + recursive = TRUE + ) + files_to_remove <- all_files_in_stan[!grepl("\\.stan$", all_files_in_stan)] + + if (length(files_to_remove) > 0) { + unlink(files_to_remove, force = TRUE) + } + + invisible(files_to_remove) +} + +cleanup_stan_artifacts() +withr::defer(cleanup_stan_artifacts(), testthat::teardown_env()) diff --git a/tests/testthat/teardown-remove-files.R b/tests/testthat/teardown-remove-files.R deleted file mode 100644 index 349324bf..00000000 --- a/tests/testthat/teardown-remove-files.R +++ /dev/null @@ -1,8 +0,0 @@ -# remove any files that aren't .stan files from resources/stan, -# e.g. files created by $compile() -all_files_in_stan <- - list.files(test_path("resources", "stan"), - full.names = TRUE, - recursive = TRUE) -not_stan_programs <- !grepl(".stan$", all_files_in_stan) -file.remove(all_files_in_stan[not_stan_programs]) diff --git a/tests/testthat/test-csv.R b/tests/testthat/test-csv.R index 89c0faf3..0922c0f0 100644 --- a/tests/testthat/test-csv.R +++ b/tests/testthat/test-csv.R @@ -1,5 +1,3 @@ -context("read_cmdstan_csv") - set_cmdstan_path() fit_bernoulli_optimize <- testing_fit("bernoulli", method = "optimize", seed = 1234) fit_bernoulli_variational <- testing_fit("bernoulli", method = "variational", seed = 123) @@ -519,9 +517,10 @@ test_that("returning time works for read_cmdstan_csv", { test_that("time from read_cmdstan_csv matches time from fit$time()", { fit <- fit_bernoulli_thin_1 - expect_equivalent( + expect_equal( read_cmdstan_csv(fit$output_files())$time$chains, - fit$time()$chains + fit$time()$chains, + ignore_attr = TRUE ) }) diff --git a/tests/testthat/test-data.R b/tests/testthat/test-data.R index b024f5e6..8fb9f63f 100644 --- a/tests/testthat/test-data.R +++ b/tests/testthat/test-data.R @@ -1,5 +1,3 @@ -context("data-utils") - set_cmdstan_path() fit <- testing_fit("bernoulli", method = "sample", seed = 123) fit_vb <- testing_fit("bernoulli", method = "variational", seed = 123) diff --git a/tests/testthat/test-example.R b/tests/testthat/test-example.R index d8c91910..9f8a47cb 100644 --- a/tests/testthat/test-example.R +++ b/tests/testthat/test-example.R @@ -1,4 +1,15 @@ -context("cmdstanr_example") +stan_program <- " + data { + int N; + array[N] int y; + } + parameters { + real theta; + } + model { + y ~ bernoulli(theta); + } + " test_that("cmdstanr_example works", { fit_mcmc <- cmdstanr_example("logistic", chains = 2, force_recompile = TRUE) @@ -13,26 +24,10 @@ test_that("cmdstanr_example works", { fit_vb <- cmdstanr_example("logistic", method = "variational") checkmate::expect_r6(fit_vb, "CmdStanVB") - expect_output(print_example_program("schools"), "vector[J] theta", fixed=TRUE) expect_output(print_example_program("schools_ncp"), "vector[J] theta_raw", fixed=TRUE) }) - -# used in multiple tests below -stan_program <- " - data { - int N; - array[N] int y; - } - parameters { - real theta; - } - model { - y ~ bernoulli(theta); - } - " - test_that("write_stan_file writes Stan file correctly", { skip_if_not_installed("rlang") f1 <- write_stan_file(stan_program) @@ -47,17 +42,16 @@ test_that("write_stan_file writes Stan file correctly", { }) test_that("write_stan_file writes to specified directory and filename", { - dir <- file.path(test_path(), "answers") + dir <- withr::local_tempdir() + explicit_dir <- withr::local_tempdir() expect_equal(dirname(f1 <- write_stan_file(stan_program, dir = dir, basename = "pasta")), absolute_path(dir)) expect_equal(f2 <- write_stan_file(stan_program, dir = dir, basename = "fruit.stan"), absolute_path(file.path(dir, "fruit.stan"))) expect_equal(f3 <- write_stan_file(stan_program, dir = dir, basename = "vegetable"), absolute_path(file.path(dir, "vegetable.stan"))) # should add .stan extension if missing - expect_equal(f4 <- write_stan_file(stan_program, dir = tempdir(), basename = "test"), - absolute_path(file.path(tempdir(), "test.stan"))) - - try(file.remove(f1, f2, f3, f4), silent = TRUE) + expect_equal(f4 <- write_stan_file(stan_program, dir = explicit_dir, basename = "test"), + absolute_path(file.path(explicit_dir, "test.stan"))) }) test_that("write_stan_file creates dir if necessary", { @@ -69,7 +63,7 @@ test_that("write_stan_file creates dir if necessary", { test_that("write_stan_file by default creates the same file for the same Stan model", { skip_if_not_installed("rlang") - dir <- file.path(test_path(), "answers") + dir <- withr::local_tempdir() f1 <- write_stan_file(stan_program, dir = dir) mtime1 <- file.info(f1)$mtime @@ -95,24 +89,16 @@ test_that("write_stan_file by default creates the same file for the same Stan mo mtime5 <- file.info(f5)$mtime expect_true(mtime1 < mtime5) - - - try(file.remove(f1, f2, f4), silent = TRUE) }) test_that("cmdstanr_write_stan_file_dir option works", { base_dir <- tempdir() - test_dir <- file.path(base_dir, "option_test") - if (!dir.exists(test_dir)) { - dir.create(test_dir) - } - options("cmdstanr_write_stan_file_dir" = test_dir) - file <- write_stan_file(stan_program) - expect_equal(repair_path(dirname(file)), repair_path(test_dir)) - options("cmdstanr_write_stan_file_dir" = NULL) + test_dir <- withr::local_tempdir(pattern = "option_test") + local({ + withr::local_options(list("cmdstanr_write_stan_file_dir" = test_dir)) + file <- write_stan_file(stan_program) + expect_equal(repair_path(dirname(file)), repair_path(test_dir)) + }) file <- write_stan_file(stan_program) expect_equal(repair_path(dirname(file)), repair_path(base_dir)) - if (!dir.exists(test_dir)) { - file.remove(test_dir) - } }) diff --git a/tests/testthat/test-failed-chains.R b/tests/testthat/test-failed-chains.R index 7a67cf2c..b76e8700 100644 --- a/tests/testthat/test-failed-chains.R +++ b/tests/testthat/test-failed-chains.R @@ -1,4 +1,3 @@ -context("failed chains") set_cmdstan_path() stan_program <- testing_stan_file("chain_fails") stan_program_init_warnings <- testing_stan_file("init_warnings") diff --git a/tests/testthat/test-fit-gq.R b/tests/testthat/test-fit-gq.R index 33cd62be..4fc43f1f 100644 --- a/tests/testthat/test-fit-gq.R +++ b/tests/testthat/test-fit-gq.R @@ -1,5 +1,3 @@ -context("fitted-gq") - set_cmdstan_path() fit <- testing_fit("bernoulli", method = "sample", seed = 123) fit_gq <- testing_fit("bernoulli_ppc", method = "generate_quantities", seed = 123, fitted_params = fit) @@ -78,7 +76,7 @@ test_that("print() method works after gq", { fit_gq$print(variable = "unknown", max_rows = 20), "Can't find the following variable(s): unknown", fixed = TRUE - ) # unknown parameter + ) out <- capture.output(fit_gq$print("y_rep")) expect_length(out, 11) # columns names + 1 y_rep diff --git a/tests/testthat/test-fit-init.R b/tests/testthat/test-fit-init.R index 7b78c73a..e4aca759 100644 --- a/tests/testthat/test-fit-init.R +++ b/tests/testthat/test-fit-init.R @@ -1,4 +1,3 @@ -context("fitted-inits") set_cmdstan_path() data_list_schools <- testing_data("schools") diff --git a/tests/testthat/test-fit-laplace.R b/tests/testthat/test-fit-laplace.R index 8991a9e1..90415e93 100644 --- a/tests/testthat/test-fit-laplace.R +++ b/tests/testthat/test-fit-laplace.R @@ -1,5 +1,3 @@ -context("fitted-laplace") - set_cmdstan_path() fit_laplace <- testing_fit("logistic", method = "laplace", seed = 100) PARAM_NAMES <- c("alpha", "beta[1]", "beta[2]", "beta[3]") diff --git a/tests/testthat/test-fit-mcmc.R b/tests/testthat/test-fit-mcmc.R index 6876fd5d..aede1110 100644 --- a/tests/testthat/test-fit-mcmc.R +++ b/tests/testthat/test-fit-mcmc.R @@ -1,5 +1,3 @@ -context("fitted-mcmc") - set_cmdstan_path() fit_mcmc <- testing_fit("logistic", method = "sample", seed = 123, chains = 2) @@ -139,7 +137,7 @@ test_that("print() method works after mcmc", { fit$print(variable = "unknown", max_rows = 20), "Can't find the following variable(s): unknown", fixed = TRUE - ) # unknown parameter + ) out <- capture.output(fit$print("theta")) expect_length(out, 9) # columns names + 8 thetas @@ -331,23 +329,20 @@ test_that("loo works for all draws storage formats", { skip_if_not_installed("loo") fit <- testing_fit("bernoulli_log_lik") - options(cmdstanr_draws_format = "draws_array") + withr::local_options(list(cmdstanr_draws_format = "draws_array")) expect_s3_class(suppressWarnings(fit$loo()), "loo") - options(cmdstanr_draws_format = "draws_df") + withr::local_options(list(cmdstanr_draws_format = "draws_df")) expect_s3_class(suppressWarnings(fit$loo()), "loo") - options(cmdstanr_draws_format = "draws_matrix") + withr::local_options(list(cmdstanr_draws_format = "draws_matrix")) expect_s3_class(suppressWarnings(fit$loo()), "loo") - options(cmdstanr_draws_format = "draws_list") + withr::local_options(list(cmdstanr_draws_format = "draws_list")) expect_s3_class(suppressWarnings(fit$loo()), "loo") - options(cmdstanr_draws_format = "draws_rvars") + withr::local_options(list(cmdstanr_draws_format = "draws_rvars")) expect_s3_class(suppressWarnings(fit$loo()), "loo") - - # reset option - options(cmdstanr_draws_format = NULL) }) test_that("draws() works for different formats", { diff --git a/tests/testthat/test-fit-mle.R b/tests/testthat/test-fit-mle.R index cd87a214..c868a37c 100644 --- a/tests/testthat/test-fit-mle.R +++ b/tests/testthat/test-fit-mle.R @@ -1,5 +1,3 @@ -context("fitted-mle") - set_cmdstan_path() fit_mle <- testing_fit("logistic", method = "optimize", seed = 123) mod <- testing_model("bernoulli") diff --git a/tests/testthat/test-fit-shared.R b/tests/testthat/test-fit-shared.R index d422c47c..8f83be31 100644 --- a/tests/testthat/test-fit-shared.R +++ b/tests/testthat/test-fit-shared.R @@ -1,5 +1,3 @@ -context("fitted-shared-methods") - set_cmdstan_path() fits <- list() fits[["sample"]] <- testing_fit("logistic", method = "sample", diff --git a/tests/testthat/test-fit-vb.R b/tests/testthat/test-fit-vb.R index e701e951..f5d2f2c9 100644 --- a/tests/testthat/test-fit-vb.R +++ b/tests/testthat/test-fit-vb.R @@ -1,5 +1,3 @@ -context("fitted-vb") - set_cmdstan_path() fit_vb <- testing_fit("logistic", method = "variational", seed = 123) fit_vb_sci_not <- testing_fit("logistic", method = "variational", seed = 123, iter = 200000, adapt_iter = 100000) diff --git a/tests/testthat/test-install.R b/tests/testthat/test-install.R index 0f2b9965..ccb781cc 100644 --- a/tests/testthat/test-install.R +++ b/tests/testthat/test-install.R @@ -1,5 +1,3 @@ -context("install") - # avoid parallel on Mac due to strange intermittent TBB errors on Github Actions CORES <- if (os_is_macos()) 1 else 2 @@ -124,8 +122,7 @@ test_that("install_cmdstan() works with version and release_url", { test_that("toolchain checks on Unix work", { skip_if(os_is_windows()) - path_backup <- Sys.getenv("PATH") - Sys.setenv("PATH" = "") + withr::local_envvar(c("PATH" = "")) if (os_is_macos()) { err_msg_cpp <- "A suitable C++ compiler was not found. Please install the command line tools for Mac with 'xcode-select --install' or install Xcode from the app store. Then restart R and run cmdstanr::check_cmdstan_toolchain()." err_msg_make <- "The 'make' tool was not found. Please install the command line tools for Mac with 'xcode-select --install' or install Xcode from the app store. Then restart R and run cmdstanr::check_cmdstan_toolchain()." @@ -143,7 +140,6 @@ test_that("toolchain checks on Unix work", { err_msg_make, fixed = TRUE ) - Sys.setenv("PATH" = path_backup) }) test_that("clean and rebuild works", { @@ -314,8 +310,7 @@ test_that("rtools4x_toolchain_path prefers static-posix when available", { if (arch_is_aarch64()) "_AARCH64" else "", "_HOME" ) - fake_rtools_home <- tempfile(pattern = "rtools-home-pref-", tmpdir = tempdir(check = TRUE)) - on.exit(unlink(fake_rtools_home, recursive = TRUE), add = TRUE) + fake_rtools_home <- withr::local_tempdir(pattern = "rtools-home-pref-") dir.create(file.path(fake_rtools_home, "x86_64-w64-mingw32.static.posix", "bin"), recursive = TRUE, showWarnings = FALSE) dir.create(file.path(fake_rtools_home, "mingw64", "bin"), @@ -338,8 +333,7 @@ test_that("rtools4x_toolchain_path falls back to mingw64 for legacy layouts", { if (arch_is_aarch64()) "_AARCH64" else "", "_HOME" ) - fake_rtools_home <- tempfile(pattern = "rtools-home-fallback-", tmpdir = tempdir(check = TRUE)) - on.exit(unlink(fake_rtools_home, recursive = TRUE), add = TRUE) + fake_rtools_home <- withr::local_tempdir(pattern = "rtools-home-fallback-") dir.create(file.path(fake_rtools_home, "mingw64", "bin"), recursive = TRUE, showWarnings = FALSE) file.create(file.path(fake_rtools_home, "mingw64", "bin", "g++.exe")) @@ -359,8 +353,7 @@ test_that("rtools4x_toolchain_path prefers ABI-compatible legacy fallback", { if (arch_is_aarch64()) "_AARCH64" else "", "_HOME" ) - fake_rtools_home <- tempfile(pattern = "rtools-home-abi-", tmpdir = tempdir(check = TRUE)) - on.exit(unlink(fake_rtools_home, recursive = TRUE), add = TRUE) + fake_rtools_home <- withr::local_tempdir(pattern = "rtools-home-abi-") dir.create(file.path(fake_rtools_home, "mingw64", "bin"), recursive = TRUE, showWarnings = FALSE) dir.create(file.path(fake_rtools_home, "ucrt64", "bin"), @@ -369,24 +362,20 @@ test_that("rtools4x_toolchain_path prefers ABI-compatible legacy fallback", { file.create(file.path(fake_rtools_home, "ucrt64", "bin", "g++.exe")) withr::with_envvar(setNames(fake_rtools_home, env_var), { - with_mocked_bindings( - { - expect_equal( - rtools4x_toolchain_path(), - repair_path(file.path(fake_rtools_home, "mingw64", "bin")) - ) - }, - is_ucrt_toolchain = function() FALSE - ) - with_mocked_bindings( - { - expect_equal( - rtools4x_toolchain_path(), - repair_path(file.path(fake_rtools_home, "ucrt64", "bin")) - ) - }, - is_ucrt_toolchain = function() TRUE - ) + local({ + local_mocked_bindings(is_ucrt_toolchain = function() FALSE) + expect_equal( + rtools4x_toolchain_path(), + repair_path(file.path(fake_rtools_home, "mingw64", "bin")) + ) + }) + local({ + local_mocked_bindings(is_ucrt_toolchain = function() TRUE) + expect_equal( + rtools4x_toolchain_path(), + repair_path(file.path(fake_rtools_home, "ucrt64", "bin")) + ) + }) }) }) @@ -396,8 +385,7 @@ test_that("check_rtools4x_windows_toolchain reports checked toolchain paths", { if (arch_is_aarch64()) "_AARCH64" else "", "_HOME" ) - fake_rtools_home <- tempfile(pattern = "rtools-home-invalid-", tmpdir = tempdir(check = TRUE)) - on.exit(unlink(fake_rtools_home, recursive = TRUE), add = TRUE) + fake_rtools_home <- withr::local_tempdir(pattern = "rtools-home-invalid-") dir.create(file.path(fake_rtools_home, "usr", "bin"), recursive = TRUE, showWarnings = FALSE) file.create(file.path(fake_rtools_home, "usr", "bin", "make.exe")) @@ -423,81 +411,91 @@ test_that("check_rtools4x_windows_toolchain reports checked toolchain paths", { }) test_that("toolchain_PATH_env_var() handles missing and configured Rtools homes", { - with_mocked_bindings( - expect_null(toolchain_PATH_env_var()), - os_is_windows = function() FALSE - ) - with_mocked_bindings( - expect_null(toolchain_PATH_env_var()), - os_is_windows = function() TRUE, - rtools4x_home_path = function() "" - ) - with_mocked_bindings( + local({ + local_mocked_bindings(os_is_windows = function() FALSE) + expect_null(toolchain_PATH_env_var()) + }) + local({ + local_mocked_bindings( + os_is_windows = function() TRUE, + rtools4x_home_path = function() "" + ) + expect_null(toolchain_PATH_env_var()) + }) + local({ + local_mocked_bindings( + os_is_windows = function() TRUE, + rtools4x_home_path = function() "C:/rtools", + rtools4x_toolchain_path = function() "C:/rtools/ucrt64/bin", + repair_path = function(path) path + ) expect_equal( toolchain_PATH_env_var(), "C:/rtools/usr/bin;C:/rtools/ucrt64/bin" - ), - os_is_windows = function() TRUE, - rtools4x_home_path = function() "C:/rtools", - rtools4x_toolchain_path = function() "C:/rtools/ucrt64/bin", - repair_path = function(path) path - ) + ) + }) }) test_that("check_rtools4x_windows_toolchain reports missing Rtools and make", { - fake_rtools_home <- tempfile(pattern = "rtools-home-missing-", tmpdir = tempdir(check = TRUE)) - on.exit(unlink(fake_rtools_home, recursive = TRUE), add = TRUE) + fake_rtools_home <- withr::local_tempdir(pattern = "rtools-home-missing-") - with_mocked_bindings( + local({ + local_mocked_bindings( + rtools4x_home_path = function() "", + rtools4x_version = function() "44" + ) expect_error( check_rtools4x_windows_toolchain(), "restart R, and then run cmdstanr::check_cmdstan_toolchain()", fixed = TRUE - ), - rtools4x_home_path = function() "", - rtools4x_version = function() "44" - ) + ) + }) dir.create(file.path(fake_rtools_home, "usr", "bin"), recursive = TRUE, showWarnings = FALSE) - with_mocked_bindings( + local({ + local_mocked_bindings( + rtools4x_home_path = function() fake_rtools_home, + rtools4x_version = function() "44" + ) expect_error( check_rtools4x_windows_toolchain(), "restart R, and then run cmdstanr::check_cmdstan_toolchain()", fixed = TRUE - ), - rtools4x_home_path = function() fake_rtools_home, - rtools4x_version = function() "44" - ) + ) + }) }) test_that("check_rtools4x_windows_toolchain validates install path and empty candidates", { - with_mocked_bindings( + local({ + local_mocked_bindings( + rtools4x_home_path = function() "C:/Program Files/Rtools44", + rtools4x_version = function() "44" + ) expect_error( check_rtools4x_windows_toolchain(), "Please reinstall the appropriate Rtools version for this R installation to a valid path", fixed = TRUE - ), - rtools4x_home_path = function() "C:/Program Files/Rtools44", - rtools4x_version = function() "44" - ) + ) + }) - fake_rtools_home <- tempfile(pattern = "rtools-home-empty-", tmpdir = tempdir(check = TRUE)) - on.exit(unlink(fake_rtools_home, recursive = TRUE), add = TRUE) + fake_rtools_home <- withr::local_tempdir(pattern = "rtools-home-empty-") dir.create(file.path(fake_rtools_home, "usr", "bin"), recursive = TRUE, showWarnings = FALSE) file.create(file.path(fake_rtools_home, "usr", "bin", "make.exe")) - with_mocked_bindings( + local({ + local_mocked_bindings( + rtools4x_home_path = function() fake_rtools_home, + rtools4x_version = function() "44", + rtools4x_toolchain_candidates = function() character() + ) expect_error( check_rtools4x_windows_toolchain(), "restart R, and then run cmdstanr::check_cmdstan_toolchain()", fixed = TRUE - ), - rtools4x_home_path = function() fake_rtools_home, - rtools4x_version = function() "44", - rtools4x_toolchain_candidates = function() character() - ) + ) + }) }) test_that("check_cmdstan_toolchain(fix = TRUE) is deprecated", { diff --git a/tests/testthat/test-json.R b/tests/testthat/test-json.R index c2c2657c..0f7788f3 100644 --- a/tests/testthat/test-json.R +++ b/tests/testthat/test-json.R @@ -1,30 +1,26 @@ -context("json") +expect_json_snapshot <- function(path, name = basename(path)) { + expect_snapshot_file(path, name = name, cran = TRUE) +} test_that("JSON output unboxing works", { temp_file <- tempfile() N <- 10 write_stan_json(list(N = N), file = temp_file) - json_output <- readLines(temp_file) - expect_known_output(cat(json_output, sep = "\n"), - file = test_path("answers", "json-unboxing.json")) + expect_json_snapshot(temp_file, "json-unboxing.json") }) test_that("JSON output for boolean is correct", { temp_file <- tempfile() N <- c(TRUE, FALSE, TRUE) write_stan_json(list(N = N), file = temp_file) - json_output <- readLines(temp_file) - expect_known_output(cat(json_output, sep = "\n"), - file = test_path("answers", "json-boolean.json")) + expect_json_snapshot(temp_file, "json-boolean.json") }) test_that("JSON output for factors is correct", { temp_file <- tempfile() N <- factor(c(0,1,2,2,1,0), labels = c("c1", "c2", "c3")) write_stan_json(list(N = N), file = temp_file) - json_output <- readLines(temp_file) - expect_known_output(cat(json_output, sep = "\n"), - file = test_path("answers", "json-factor.json")) + expect_json_snapshot(temp_file, "json-factor.json") }) test_that("JSON output for integer vector is correct", { @@ -32,9 +28,7 @@ test_that("JSON output for integer vector is correct", { N <- c(1.0, 2.0, 3, 4) write_stan_json(list(N = N), file = temp_file) - json_output <- readLines(temp_file) - expect_known_output(cat(json_output, sep = "\n"), - file = test_path("answers", "json-integer.json")) + expect_json_snapshot(temp_file, "json-integer.json") }) test_that("JSON output for data frame and matrix is correct", { @@ -47,16 +41,9 @@ test_that("JSON output for data frame and matrix is correct", { write_stan_json(list(X = df), file = temp_file_df) write_stan_json(list(X = mat), file = temp_file_mat) - json_output_mat <- readLines(temp_file_df) - json_output_df <- readLines(temp_file_mat) - expect_identical(json_output_df, json_output_mat) + expect_identical(readLines(temp_file_df), readLines(temp_file_mat)) - # Floating-point error introduced in jsonlite 1.8.5 - # https://github.com/jeroen/jsonlite/issues/420 - if (packageVersion("jsonlite") != "1.8.5") { - expect_known_output(cat(json_output_df, sep = "\n"), - file = test_path("answers", "json-df-matrix.json")) - } + expect_json_snapshot(temp_file_df, "json-df-matrix.json") }) test_that("JSON output for list of vectors is correct", { @@ -64,9 +51,7 @@ test_that("JSON output for list of vectors is correct", { N <- list(c(1,2,3), c(4,5,6)) write_stan_json(list(N = N), file = temp_file) - json_output <- readLines(temp_file) - expect_known_output(cat(json_output, sep = "\n"), - file = test_path("answers", "json-vector-lists.json")) + expect_json_snapshot(temp_file, "json-vector-lists.json") }) test_that("JSON output for list of matrices is correct", { @@ -76,9 +61,7 @@ test_that("JSON output for list of matrices is correct", { matrix(5:8, nrow = 2, byrow = TRUE) ) write_stan_json(list(M = matrices), file = temp_file) - json_output <- readLines(temp_file) - expect_known_output(cat(json_output, sep = "\n"), - file = test_path("answers", "json-matrix-lists.json")) + expect_json_snapshot(temp_file, "json-matrix-lists.json") }) test_that("JSON output for table is correct", { @@ -86,19 +69,13 @@ test_that("JSON output for table is correct", { f <- factor(rep(1:4, each = 5)) write_stan_json(list(x = table(f)), file = temp_file) - json_output <- readLines(temp_file) - expect_known_output(cat(json_output, sep = "\n"), - file = test_path("answers", "json-table-vector.json")) + expect_json_snapshot(temp_file, "json-table-vector.json") write_stan_json(list(x = table(f, f)), file = temp_file) - json_output <- readLines(temp_file) - expect_known_output(cat(json_output, sep = "\n"), - file = test_path("answers", "json-table-matrix.json")) + expect_json_snapshot(temp_file, "json-table-matrix.json") write_stan_json(list(x = table(f, f, f)), file = temp_file) - json_output <- readLines(temp_file) - expect_known_output(cat(json_output, sep = "\n"), - file = test_path("answers", "json-table-array.json")) + expect_json_snapshot(temp_file, "json-table-array.json") }) test_that("write_stan_json errors if NAs", { diff --git a/tests/testthat/test-knitr.R b/tests/testthat/test-knitr.R index 7e4400c5..6141ac9c 100644 --- a/tests/testthat/test-knitr.R +++ b/tests/testthat/test-knitr.R @@ -1,5 +1,3 @@ -context("knitr engine") - test_that("eng_cmdstan throws correct errors", { skip_if_not_installed("knitr") expect_error(eng_cmdstan(list(output.var = 1)), "must be a character string") @@ -24,7 +22,7 @@ test_that("eng_cmdstan works", { )) expect_interactive_message(eng_cmdstan(opts), "Compiling Stan program") opts$eval <- FALSE - expect_silent(eng_cmdstan(opts)) + expect_noninteractive_silent(eng_cmdstan(opts)) }) test_that("register_knitr_engine works with and without override", { diff --git a/tests/testthat/test-model-code-print.R b/tests/testthat/test-model-code-print.R index e3409ed9..16fdaa81 100644 --- a/tests/testthat/test-model-code-print.R +++ b/tests/testthat/test-model-code-print.R @@ -1,13 +1,10 @@ -context("model-code-print") - set_cmdstan_path() stan_program <- testing_stan_file("bernoulli") mod <- testing_model("bernoulli") - test_that("code() and print() methods work", { - expect_known_output(mod$print(), file = test_path("answers", "model-print-output.stan")) - expect_known_value(mod$code(), file = test_path("answers", "model-code-output.rds")) + expect_snapshot_output(mod$print(), cran = TRUE) + expect_snapshot_value(mod$code(), style = "deparse", cran = TRUE) }) test_that("code() and print() still work if file is removed", { diff --git a/tests/testthat/test-model-compile.R b/tests/testthat/test-model-compile.R index 4bc60a95..bd9b74da 100644 --- a/tests/testthat/test-model-compile.R +++ b/tests/testthat/test-model-compile.R @@ -1,5 +1,3 @@ -context("model-compile") - set_cmdstan_path() stan_program <- cmdstan_example_file() mod <- cmdstan_model(stan_file = stan_program, compile = FALSE) @@ -78,7 +76,6 @@ test_that("compile() method overwrites binaries", { mod$compile(quiet = TRUE) old_time = file.mtime(mod$exe_file()) mod$compile(quiet = TRUE, force_recompile = TRUE) - new_time = expect_gt(file.mtime(mod$exe_file()), old_time) }) @@ -113,6 +110,7 @@ test_that("compilation works with include_paths", { }) test_that("name in STANCFLAGS is set correctly", { + local_reproducible_output() out <- utils::capture.output(mod$compile(quiet = FALSE, force_recompile = TRUE)) if(os_is_windows() && !os_is_wsl()) { out_no_name <- "bin/stanc.exe --name='bernoulli_model' --o" @@ -122,7 +120,14 @@ test_that("name in STANCFLAGS is set correctly", { out_name <- "bin/stanc --name='bernoulli2_model' --o" } expect_output(print(out), out_no_name) - out <- utils::capture.output(mod$compile(quiet = FALSE, force_recompile = TRUE, stanc_options = list(name = "bernoulli2_model"))) + + out <- utils::capture.output( + mod$compile( + quiet = FALSE, + force_recompile = TRUE, + stanc_options = list(name = "bernoulli2_model") + ) + ) expect_output(print(out), out_name) }) @@ -299,11 +304,9 @@ test_that("check_syntax() works", { stan_file <- testing_stan_file("bernoulli") mod_ok <- cmdstan_model(stan_file, compile = FALSE) - expect_true( - expect_message( - mod_ok$check_syntax(), - "Stan program is syntactically correct" - ) + expect_message( + mod_ok$check_syntax(), + "Stan program is syntactically correct" ) expect_message( mod_ok$check_syntax(quiet = TRUE), @@ -832,9 +835,9 @@ test_that("dirname of stan_file is used as include path if no other paths suppli }) test_that("STANCFLAGS from get_cmdstan_flags() are included in compile output", { + local_reproducible_output() real_get_cmdstan_flags <- get_cmdstan_flags - out <- with_mocked_bindings( - utils::capture.output(mod$compile(quiet = FALSE, force_recompile = TRUE)), + local_mocked_bindings( get_cmdstan_flags = function(flag_name) { if (identical(flag_name, "STANCFLAGS")) { c("--O1", "--warn-pedantic") @@ -843,6 +846,7 @@ test_that("STANCFLAGS from get_cmdstan_flags() are included in compile output", } } ) + out <- utils::capture.output(mod$compile(quiet = FALSE, force_recompile = TRUE)) if(os_is_windows() && !os_is_wsl()) { out_w_flags <- "bin/stanc.exe --name='bernoulli_model'[[:space:]]+--O1[[:space:]]+--warn-pedantic[[:space:]]+--o" } else { diff --git a/tests/testthat/test-model-data.R b/tests/testthat/test-model-data.R index 2f91e575..68eb4c74 100644 --- a/tests/testthat/test-model-data.R +++ b/tests/testthat/test-model-data.R @@ -1,4 +1,3 @@ -context("model-data") # see separate test-json for testing writing data to JSON set_cmdstan_path() @@ -33,5 +32,6 @@ test_that("empty data list doesn't error if no data block", { ) # would error if fitting failed + expect_no_error(fit$draws()) expect_silent(fit$draws()) }) diff --git a/tests/testthat/test-model-diagnose.R b/tests/testthat/test-model-diagnose.R index 76ab4a47..0ef2db68 100644 --- a/tests/testthat/test-model-diagnose.R +++ b/tests/testthat/test-model-diagnose.R @@ -1,5 +1,3 @@ -context("model-diagnose") - set_cmdstan_path() mod <- testing_model("bernoulli") data_list <- testing_data("bernoulli") @@ -34,11 +32,11 @@ ok_arg_sci_nota_values <- list( test_that("diagnose() method runs when all arguments specified validly", { # specifying all arguments validly fit1 <- do.call(mod$diagnose, ok_arg_values) - expect_is(fit1, "CmdStanDiagnose") + expect_s3_class(fit1, "CmdStanDiagnose") # leaving all at default (except 'data' and 'seed') fit2 <- mod$diagnose(data = data_list, seed = 123) - expect_is(fit2, "CmdStanDiagnose") + expect_s3_class(fit2, "CmdStanDiagnose") }) test_that("diagnose() method runs when arguments are specified in scientific notation", { @@ -46,7 +44,7 @@ test_that("diagnose() method runs when arguments are specified in scientific not # specifying all arguments validly fit1 <- do.call(mod$diagnose, ok_arg_sci_nota_values) - expect_is(fit1, "CmdStanDiagnose") + expect_s3_class(fit1, "CmdStanDiagnose") }) test_that("diagnose() method errors for any invalid argument before calling cmdstan", { diff --git a/tests/testthat/test-model-expose-functions.R b/tests/testthat/test-model-expose-functions.R index 9bcf1382..0cbc4f8b 100644 --- a/tests/testthat/test-model-expose-functions.R +++ b/tests/testthat/test-model-expose-functions.R @@ -1,5 +1,3 @@ -context("model-expose-functions") - # Standalone functions not expected to work on WSL yet skip_if(os_is_wsl()) diff --git a/tests/testthat/test-model-generate_quantities.R b/tests/testthat/test-model-generate_quantities.R index 7c641815..ec1924aa 100644 --- a/tests/testthat/test-model-generate_quantities.R +++ b/tests/testthat/test-model-generate_quantities.R @@ -1,5 +1,3 @@ -context("model-generate-quantities") - set_cmdstan_path() fit <- testing_fit("bernoulli", method = "sample", seed = 123) mod_gq <- testing_model("bernoulli_ppc") @@ -25,11 +23,11 @@ bad_arg_values <- list( test_that("generate_quantities() method runs when all arguments specified validly", { # specifying all arguments validly expect_gq_output(fit1 <- do.call(mod_gq$generate_quantities, ok_arg_values)) - expect_is(fit1, "CmdStanGQ") + expect_s3_class(fit1, "CmdStanGQ") # leaving all at default (except 'data') expect_gq_output(fit2 <- mod_gq$generate_quantities(fitted_params = fit, data = data_list)) - expect_is(fit2, "CmdStanGQ") + expect_s3_class(fit2, "CmdStanGQ") }) test_that("generate_quantities() method errors for any invalid argument before calling cmdstan", { diff --git a/tests/testthat/test-model-init.R b/tests/testthat/test-model-init.R index ebe54677..e2e30dfc 100644 --- a/tests/testthat/test-model-init.R +++ b/tests/testthat/test-model-init.R @@ -1,5 +1,3 @@ -context("model-init") - set_cmdstan_path() mod <- testing_model("bernoulli") data_list <- testing_data("bernoulli") @@ -240,7 +238,7 @@ test_that("error if init function specified incorrectly", { }) test_that("print message if not all parameters are initialized", { - options(cmdstanr_warn_inits = NULL) # should default to TRUE + withr::local_options(list(cmdstanr_warn_inits = NULL)) # should default to TRUE init_list <- list( list( alpha = 1 @@ -273,7 +271,7 @@ test_that("print message if not all parameters are initialized", { }) test_that("No message printed if options(cmdstanr_warn_inits=FALSE)", { - options(cmdstanr_warn_inits = FALSE) + withr::local_options(list(cmdstanr_warn_inits = FALSE)) expect_message( utils::capture.output(mod_logistic$optimize(data = data_list_logistic, init = list(list(a = 0)), seed = 123)), regexp = NA @@ -286,7 +284,6 @@ test_that("No message printed if options(cmdstanr_warn_inits=FALSE)", { utils::capture.output(mod_logistic$sample(data = data_list_logistic, init = list(list(alpha = 1),list(alpha = 1)), chains = 2, seed = 123)), regexp = NA ) - options(cmdstanr_warn_inits = TRUE) }) test_that("Initial values for single-element containers treated correctly", { diff --git a/tests/testthat/test-model-laplace.R b/tests/testthat/test-model-laplace.R index 00dc78bb..961883c8 100644 --- a/tests/testthat/test-model-laplace.R +++ b/tests/testthat/test-model-laplace.R @@ -1,5 +1,3 @@ -context("model-laplace") - set_cmdstan_path() mod <- testing_model("logistic") data_list <- testing_data("logistic") @@ -48,7 +46,7 @@ test_that("laplace() method errors for any invalid argument before calling cmdst test_that("laplace() runs when all arguments specified validly", { # specifying all arguments validly expect_laplace_output(fit1 <- do.call(mod$laplace, ok_arg_values)) - expect_is(fit1, "CmdStanLaplace") + expect_s3_class(fit1, "CmdStanLaplace") # check that correct arguments were indeed passed to CmdStan expect_equal(fit1$metadata()$refresh, ok_arg_values$refresh) @@ -61,7 +59,7 @@ test_that("laplace() runs when all arguments specified validly", { # leaving all at default (except 'data') expect_laplace_output(fit2 <- mod$laplace(data = data_list, seed = 123)) - expect_is(fit2, "CmdStanLaplace") + expect_s3_class(fit2, "CmdStanLaplace") }) test_that("laplace() all valid 'mode' inputs give same results", { @@ -72,12 +70,12 @@ test_that("laplace() all valid 'mode' inputs give same results", { fit3 <- mod$laplace(data = data_list, mode = NULL, seed = 100, refresh = 0) }) - expect_is(fit1, "CmdStanLaplace") - expect_is(fit2, "CmdStanLaplace") - expect_is(fit3, "CmdStanLaplace") - expect_is(fit1$mode(), "CmdStanMLE") - expect_is(fit2$mode(), "CmdStanMLE") - expect_is(fit3$mode(), "CmdStanMLE") + expect_s3_class(fit1, "CmdStanLaplace") + expect_s3_class(fit2, "CmdStanLaplace") + expect_s3_class(fit3, "CmdStanLaplace") + expect_s3_class(fit1$mode(), "CmdStanMLE") + expect_s3_class(fit2$mode(), "CmdStanMLE") + expect_s3_class(fit3$mode(), "CmdStanMLE") expect_equal(fit1$mode()$mle(), fit2$mode()$mle()) expect_equal(fit1$mode()$mle(), fit3$mode()$mle()) expect_equal(fit1$lp(), fit2$lp()) diff --git a/tests/testthat/test-model-methods.R b/tests/testthat/test-model-methods.R index 1e38ad25..c502af15 100644 --- a/tests/testthat/test-model-methods.R +++ b/tests/testthat/test-model-methods.R @@ -1,4 +1,3 @@ -context("model-methods") skip_if(os_is_wsl()) set_cmdstan_path() diff --git a/tests/testthat/test-model-optimize.R b/tests/testthat/test-model-optimize.R index a3bd89c6..204a5b73 100644 --- a/tests/testthat/test-model-optimize.R +++ b/tests/testthat/test-model-optimize.R @@ -1,5 +1,3 @@ -context("model-optimize") - set_cmdstan_path() mod <- testing_model("bernoulli") data_list <- testing_data("bernoulli") @@ -45,17 +43,17 @@ ok_arg_sci_nota_values <- list( test_that("optimize() method runs when all arguments specified validly", { # specifying all arguments validly expect_optim_output(fit1 <- do.call(mod$optimize, ok_arg_values)) - expect_is(fit1, "CmdStanMLE") + expect_s3_class(fit1, "CmdStanMLE") # leaving all at default (except 'data') expect_optim_output(fit2 <- mod$optimize(data = data_list, seed = 123)) - expect_is(fit2, "CmdStanMLE") + expect_s3_class(fit2, "CmdStanMLE") }) test_that("optimize() method runs when arguments are specified in scientific notation", { # specifying all arguments validly expect_optim_output(fit1 <- do.call(mod$optimize, ok_arg_sci_nota_values)) - expect_is(fit1, "CmdStanMLE") + expect_s3_class(fit1, "CmdStanMLE") }) test_that("optimize() warns if threads specified but not enabled", { diff --git a/tests/testthat/test-model-output_dir.R b/tests/testthat/test-model-output_dir.R index f5284827..1593113f 100644 --- a/tests/testthat/test-model-output_dir.R +++ b/tests/testthat/test-model-output_dir.R @@ -1,19 +1,14 @@ -context("model-output_dir-output-basename") - set_cmdstan_path() -sandbox <- file.path(tempdir(check = TRUE), "sandbox") -if (!dir.exists(sandbox)) { - dir.create(sandbox) - on.exit(unlink(sandbox, recursive = TRUE)) + +local_output_sandbox <- function(pattern = "sandbox", .local_envir = parent.frame()) { + withr::local_tempdir(pattern = pattern, .local_envir = .local_envir) } test_that("all fitting methods work with output_dir", { + sandbox <- local_output_sandbox() for (method in c("sample", "optimize", "variational")) { method_dir <- file.path(sandbox, method) - if (!dir.exists(method_dir)) { - dir.create(method_dir) - on.exit(unlink(method_dir, recursive = TRUE)) - } + dir.create(method_dir, recursive = TRUE, showWarnings = FALSE) # WSL models use internal WSL tempdir if (!os_is_wsl()) { @@ -77,6 +72,7 @@ test_that("all fitting methods work with output_dir", { }) test_that("error if output_dir is invalid", { + sandbox <- local_output_sandbox() expect_error( testing_fit("bernoulli", output_dir = "NOT_A_DIR"), "Directory 'NOT_A_DIR' does not exist", @@ -91,20 +87,17 @@ test_that("error if output_dir is invalid", { # FIXME: how do I create an unreadable file on windows? not_readable <- file.path(sandbox, "locked") dir.create(not_readable, mode = "220") + skip_if(file.access(not_readable, 4) == 0, + "temp filesystem does not support unreadable test directories") expect_error( testing_fit("bernoulli", output_dir = not_readable), "not readable" ) } - file.remove(list.files(sandbox, full.names = TRUE, recursive = TRUE)) }) test_that("output_dir works with trailing /", { - test_dir <- file.path(tempdir(check = TRUE), "output_dir") - if (dir.exists(test_dir)) { - unlink(test_dir, recursive = TRUE) - } - dir.create(test_dir) + test_dir <- withr::local_tempdir(pattern = "output_dir") fit <- testing_fit( "bernoulli", method = "sample", diff --git a/tests/testthat/test-model-pathfinder.R b/tests/testthat/test-model-pathfinder.R index 58058010..58ff48be 100644 --- a/tests/testthat/test-model-pathfinder.R +++ b/tests/testthat/test-model-pathfinder.R @@ -1,5 +1,3 @@ -context("model-pathfinder") - set_cmdstan_path() stan_program <- testing_stan_file("bernoulli") mod <- testing_model("bernoulli") @@ -103,15 +101,15 @@ expect_pathfinder_output <- function(object, num_chains = NULL) { test_that("Pathfinder Runs", { expect_pathfinder_output(fit <- mod$pathfinder(data=data_list, seed=1234, refresh = 0)) - expect_is(fit, "CmdStanPathfinder") + expect_s3_class(fit, "CmdStanPathfinder") }) test_that("pathfinder() method works with data files", { expect_pathfinder_output(fit_r <- mod$pathfinder(data = data_file_r)) - expect_is(fit_r, "CmdStanPathfinder") + expect_s3_class(fit_r, "CmdStanPathfinder") expect_pathfinder_output(fit_json <- mod$pathfinder(data = data_file_json)) - expect_is(fit_json, "CmdStanPathfinder") + expect_s3_class(fit_json, "CmdStanPathfinder") }) test_that("pathfinder() method works with init file", { @@ -132,7 +130,7 @@ test_that("pathfinder() method works with init function and default paths", { test_that("pathfinder() method runs when all arguments specified", { expect_pathfinder_output(fit <- do.call(mod$pathfinder, ok_arg_values)) - expect_is(fit, "CmdStanPathfinder") + expect_s3_class(fit, "CmdStanPathfinder") }) test_that("pathfinder() method runs when the stan file is removed", { @@ -147,7 +145,8 @@ test_that("pathfinder() method runs when the stan file is removed", { test_that("no error when checking estimates after failure", { fit <- cmdstanr_example("schools", method = "pathfinder", seed = 123) # optim always fails for this - expect_silent(fit$summary()) # no error + expect_no_error(fit$summary()) + expect_silent(fit$summary()) }) test_that("no output with show_messages = FALSE", { @@ -156,4 +155,3 @@ test_that("no output with show_messages = FALSE", { ) expect_equal(length(output), 0) }) - diff --git a/tests/testthat/test-model-sample-metric.R b/tests/testthat/test-model-sample-metric.R index 422442fa..5789e583 100644 --- a/tests/testthat/test-model-sample-metric.R +++ b/tests/testthat/test-model-sample-metric.R @@ -1,5 +1,3 @@ -context("model-sample-metric") - set_cmdstan_path() mod <- testing_model("bernoulli") data_list <- testing_data("bernoulli") diff --git a/tests/testthat/test-model-sample.R b/tests/testthat/test-model-sample.R index f9ce48b4..85c5dbb8 100644 --- a/tests/testthat/test-model-sample.R +++ b/tests/testthat/test-model-sample.R @@ -1,5 +1,3 @@ -context("model-sample") - set_cmdstan_path() stan_program <- testing_stan_file("bernoulli") mod <- testing_model("bernoulli") @@ -85,15 +83,15 @@ bad_arg_values_3 <- list( test_that("sample() method works with data list", { expect_sample_output(fit <- mod$sample(data = data_list, chains = 1), 1) - expect_is(fit, "CmdStanMCMC") + expect_s3_class(fit, "CmdStanMCMC") }) test_that("sample() method works with data files", { expect_sample_output(fit_r <- mod$sample(data = data_file_r, chains = 1), 1) - expect_is(fit_r, "CmdStanMCMC") + expect_s3_class(fit_r, "CmdStanMCMC") expect_sample_output(fit_json <- mod$sample(data = data_file_json, chains = 1), 1) - expect_is(fit_json, "CmdStanMCMC") + expect_s3_class(fit_json, "CmdStanMCMC") }) test_that("sample() method works with init file", { @@ -109,7 +107,7 @@ test_that("sample() method works with init file", { test_that("sample() method runs when all arguments specified", { expect_sample_output(fit <- do.call(mod$sample, ok_arg_values), 2) - expect_is(fit, "CmdStanMCMC") + expect_s3_class(fit, "CmdStanMCMC") }) test_that("sample() method runs when the stan file is removed", { @@ -179,14 +177,14 @@ test_that("sampling in parallel works", { }) test_that("mc.cores option detected", { - options(mc.cores = 3) + withr::local_options(list(mc.cores = 3)) expect_output( mod$sample(data = data_list, chains = 3), "Running MCMC with 3 parallel chains", fixed = TRUE ) - options(mc.cores = NULL) + withr::local_options(list(mc.cores = NULL)) expect_output( mod$sample(data = data_list, chains = 2), "Running MCMC with 2 sequential chains", @@ -198,7 +196,7 @@ test_that("sample() method runs when fixed_param = TRUE", { mod_fp$compile() expect_sample_output(fit_1000 <- mod_fp$sample(fixed_param = TRUE, iter_sampling = 1000), 4) - expect_is(fit_1000, "CmdStanMCMC") + expect_s3_class(fit_1000, "CmdStanMCMC") expect_equal(dim(fit_1000$draws()), c(1000,4,10)) expect_sample_output(fit_500 <- mod_fp$sample(fixed_param = TRUE, iter_sampling = 500), 4) @@ -221,15 +219,15 @@ test_that("sample() method runs when adapt_engaged = FALSE", { test_that("chain_ids work with sample()", { mod$compile() expect_sample_output(fit12 <- mod$sample(data = data_list, chains = 2, chain_ids = c(10,12))) - expect_is(fit12, "CmdStanMCMC") + expect_s3_class(fit12, "CmdStanMCMC") expect_equal(fit12$metadata()$id, c(10,12)) expect_sample_output(fit12 <- mod$sample(data = data_list, chains = 2, chain_ids = c(100,7))) - expect_is(fit12, "CmdStanMCMC") + expect_s3_class(fit12, "CmdStanMCMC") expect_equal(fit12$metadata()$id, c(100,7)) expect_sample_output(fit12 <- mod$sample(data = data_list, chains = 1, chain_ids = c(6))) - expect_is(fit12, "CmdStanMCMC") + expect_s3_class(fit12, "CmdStanMCMC") expect_equal(fit12$metadata()$id, c(6)) expect_error(mod$sample(data = data_list, chains = 1, chain_ids = c(0)), @@ -369,7 +367,7 @@ test_that("All output can be suppressed by show_messages", { stan_program <- testing_stan_file("bernoulli") data_list <- testing_data("bernoulli") mod <- cmdstan_model(stan_program, force_recompile = TRUE) - options("cmdstanr_verbose" = FALSE) + withr::local_options(list("cmdstanr_verbose" = FALSE)) output <- capture.output( fit <- mod$sample(data = data_list, show_messages = FALSE) ) diff --git a/tests/testthat/test-model-sample_mpi.R b/tests/testthat/test-model-sample_mpi.R index 1ea0f9a6..9463cbbb 100644 --- a/tests/testthat/test-model-sample_mpi.R +++ b/tests/testthat/test-model-sample_mpi.R @@ -1,5 +1,3 @@ -context("model-sample_mpi") - test_that("sample_mpi() works", { skip_if(!mpi_toolchain_present()) mpi_file <- write_stan_file(" @@ -35,9 +33,11 @@ test_that("sample_mpi() works", { if (os_is_wsl()) { # Default GHA WSL install runs as root, which MPI discourages # Specify that this is safe to ignore for this test - Sys.setenv("OMPI_ALLOW_RUN_AS_ROOT"=1) - Sys.setenv("OMPI_ALLOW_RUN_AS_ROOT_CONFIRM"=1) - Sys.setenv("WSLENV"="OMPI_ALLOW_RUN_AS_ROOT/u:OMPI_ALLOW_RUN_AS_ROOT_CONFIRM/u") + withr::local_envvar(c( + "OMPI_ALLOW_RUN_AS_ROOT" = "1", + "OMPI_ALLOW_RUN_AS_ROOT_CONFIRM" = "1", + "WSLENV" = "OMPI_ALLOW_RUN_AS_ROOT/u:OMPI_ALLOW_RUN_AS_ROOT_CONFIRM/u" + )) } utils::capture.output( diff --git a/tests/testthat/test-model-variables.R b/tests/testthat/test-model-variables.R index 5ca43ef2..ae71d29f 100644 --- a/tests/testthat/test-model-variables.R +++ b/tests/testthat/test-model-variables.R @@ -1,5 +1,3 @@ -context("model-variables") - set_cmdstan_path() test_that("$variables() work correctly with example models", { diff --git a/tests/testthat/test-model-variational.R b/tests/testthat/test-model-variational.R index 6c917d1e..25b1fc77 100644 --- a/tests/testthat/test-model-variational.R +++ b/tests/testthat/test-model-variational.R @@ -1,5 +1,3 @@ -context("model-variational") - set_cmdstan_path() mod <- testing_model("bernoulli") data_list <- testing_data("bernoulli") @@ -46,11 +44,11 @@ bad_arg_values <- list( test_that("variational() method runs when all arguments specified validly", { # specifying all arguments validly expect_vb_output(fit1 <- do.call(mod$variational, ok_arg_values)) - expect_is(fit1, "CmdStanVB") + expect_s3_class(fit1, "CmdStanVB") # leaving all at default (except data and seed) expect_vb_output(fit2 <- mod$variational(data = data_list, seed = 123)) - expect_is(fit2, "CmdStanVB") + expect_s3_class(fit2, "CmdStanVB") }) test_that("variational() warns if threads specified but not enabled", { diff --git a/tests/testthat/test-opencl.R b/tests/testthat/test-opencl.R index 92858141..44f5fc77 100644 --- a/tests/testthat/test-opencl.R +++ b/tests/testthat/test-opencl.R @@ -1,5 +1,3 @@ -context("opencl") - set_cmdstan_path() fit <- testing_fit("bernoulli", method = "sample", seed = 123, chains = 1) @@ -128,4 +126,3 @@ test_that("all methods run with valid opencl_ids", { expect_false(is.null(fit$metadata()$device)) expect_false(is.null(fit$metadata()$platform)) }) - diff --git a/tests/testthat/test-path.R b/tests/testthat/test-path.R index e9e7eda0..5bc25819 100644 --- a/tests/testthat/test-path.R +++ b/tests/testthat/test-path.R @@ -1,5 +1,3 @@ -context("paths") - Sys.unsetenv("CMDSTAN") PATH <- absolute_path(set_cmdstan_path()) VERSION <- cmdstan_version() @@ -28,7 +26,7 @@ test_that("Setting bad path leads to warning (can't find directory)", { test_that("Setting bad path from env leads to warning (can't find directory)", { unset_cmdstan_path() .cmdstanr$WSL <- TRUE - Sys.setenv(CMDSTAN = "BAD_PATH") + withr::local_envvar(c(CMDSTAN = "BAD_PATH")) expect_warning( cmdstanr_initialize(), "Can't find directory specified by environment variable" @@ -36,30 +34,26 @@ test_that("Setting bad path from env leads to warning (can't find directory)", { expect_null(.cmdstanr$PATH) expect_null(.cmdstanr$VERSION) expect_false(isTRUE(.cmdstanr$WSL)) - Sys.unsetenv("CMDSTAN") }) test_that("Setting path from env var is detected", { unset_cmdstan_path() expect_true(is.null(.cmdstanr$VERSION)) - Sys.setenv(CMDSTAN = PATH) + withr::local_envvar(c(CMDSTAN = PATH)) expect_silent(cmdstanr_initialize()) expect_equal(cmdstan_path(), PATH) expect_false(is.null(.cmdstanr$VERSION)) - Sys.unsetenv("CMDSTAN") }) test_that("Unsupported CmdStan path from env var is rejected", { unset_cmdstan_path() .cmdstanr$WSL <- TRUE - parent_dir <- file.path(tempdir(check = TRUE), "cmdstan-env-parent") + parent_dir <- withr::local_tempdir(pattern = "cmdstan-env-parent") old_install <- file.path(parent_dir, "cmdstan-2.34.0") dir.create(old_install, recursive = TRUE, showWarnings = FALSE) - on.exit(unlink(parent_dir, recursive = TRUE), add = TRUE) - on.exit(Sys.unsetenv("CMDSTAN"), add = TRUE) writeLines("CMDSTAN_VERSION := 2.34.0", con = file.path(old_install, "makefile")) - Sys.setenv(CMDSTAN = parent_dir) + withr::local_envvar(c(CMDSTAN = parent_dir)) suppressWarnings(cmdstanr_initialize()) expect_false(identical(.cmdstanr$PATH, absolute_path(old_install))) expect_false(identical(.cmdstanr$VERSION, "2.34.0")) @@ -74,12 +68,9 @@ test_that("Existing CMDSTAN env path with no install resets cached state", { .cmdstanr$PATH <- PATH .cmdstanr$VERSION <- VERSION .cmdstanr$WSL <- TRUE - empty_parent <- file.path(tempdir(check = TRUE), "cmdstan-empty-parent") - dir.create(empty_parent, recursive = TRUE, showWarnings = FALSE) - on.exit(unlink(empty_parent, recursive = TRUE), add = TRUE) - on.exit(Sys.unsetenv("CMDSTAN"), add = TRUE) + empty_parent <- withr::local_tempdir(pattern = "cmdstan-empty-parent") - Sys.setenv(CMDSTAN = empty_parent) + withr::local_envvar(c(CMDSTAN = empty_parent)) expect_warning( cmdstanr_initialize(), "No CmdStan installation found in the path specified by the environment variable 'CMDSTAN'.", @@ -129,7 +120,7 @@ test_that("cmdstan_version() behaves correctly when version is not set", { }) test_that("Warning message is thrown if can't detect version number", { - path <- testthat::test_path("answers") # valid path but not cmdstan + path <- withr::local_tempdir() # valid path but not cmdstan expect_warning( set_cmdstan_path(path), "Can't find CmdStan makefile to detect version number" @@ -146,9 +137,8 @@ test_that("Setting path rejects unsupported CmdStan versions", { .cmdstanr$WSL <- old_wsl }) - path <- file.path(tempdir(check = TRUE), "cmdstan-2.34.0") + path <- file.path(withr::local_tempdir(pattern = "cmdstan-unsupported"), "cmdstan-2.34.0") dir.create(path, recursive = TRUE, showWarnings = FALSE) - on.exit(unlink(path, recursive = TRUE), add = TRUE) writeLines("CMDSTAN_VERSION := 2.34.0", con = file.path(path, "makefile")) expect_warning( @@ -172,10 +162,9 @@ test_that("unset_cmdstan_path() also resets WSL state", { }) test_that("cmdstan_default_path() respects custom install directories", { - installs <- file.path(tempdir(check = TRUE), "cmdstan-custom-installs") + installs <- withr::local_tempdir(pattern = "cmdstan-custom-installs") dir.create(file.path(installs, "cmdstan-2.35.0"), recursive = TRUE, showWarnings = FALSE) dir.create(file.path(installs, "cmdstan-2.36.0"), recursive = TRUE, showWarnings = FALSE) - on.exit(unlink(installs, recursive = TRUE), add = TRUE) expect_equal( cmdstan_default_path(dir = installs), @@ -184,9 +173,7 @@ test_that("cmdstan_default_path() respects custom install directories", { }) test_that("cmdstan_default_path() returns NULL for empty custom install directories", { - installs <- file.path(tempdir(check = TRUE), "cmdstan-empty-installs") - dir.create(installs, recursive = TRUE, showWarnings = FALSE) - on.exit(unlink(installs, recursive = TRUE), add = TRUE) + installs <- withr::local_tempdir(pattern = "cmdstan-empty-installs") expect_null(cmdstan_default_path(dir = installs)) }) diff --git a/tests/testthat/test-profiling.R b/tests/testthat/test-profiling.R index 638f1fdc..dd45c45c 100644 --- a/tests/testthat/test-profiling.R +++ b/tests/testthat/test-profiling.R @@ -1,5 +1,3 @@ -context("profiling") - set_cmdstan_path() diff --git a/tests/testthat/test-threads.R b/tests/testthat/test-threads.R index fb5eec61..5e9f04f7 100644 --- a/tests/testthat/test-threads.R +++ b/tests/testthat/test-threads.R @@ -1,5 +1,3 @@ -context("threading") - set_cmdstan_path() stan_program <- testing_stan_file("bernoulli") stan_gq_program <- testing_stan_file("bernoulli_ppc") diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index a2c39f15..b27628c5 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -1,5 +1,3 @@ -context("utils") - set_cmdstan_path() fit_mcmc <- testing_fit("logistic", method = "sample", seed = 123, chains = 2)