Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 243 additions & 1 deletion R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ cmdstan_model <- function(stan_file = NULL, exe_file = NULL, compile = TRUE, ...
#' [`$hpp_file()`][model-method-compile] | Return the file path to the `.hpp` file containing the generated C++ code. |
#' [`$save_hpp_file()`][model-method-compile] | Save the `.hpp` file containing the generated C++ code. |
#' [`$expose_functions()`][model-method-expose_functions] | Expose Stan functions for use in R. |
#' [`$cmdstan_defaults()`][model-method-cmdstan_defaults] | Get CmdStan default argument values for a method. |
#'
#' ## Diagnostics
#'
Expand Down Expand Up @@ -2209,6 +2210,51 @@ expose_functions = function(global = FALSE, verbose = FALSE) {
CmdStanModel$set("public", name = "expose_functions", value = expose_functions)


#' Get CmdStan default argument values
#'
#' @name model-method-cmdstan_defaults
#' @aliases cmdstan_defaults
#' @family CmdStanModel methods
#'
#' @description The `$cmdstan_defaults()` method of a [`CmdStanModel`]
#' object queries the compiled model binary for the default argument
#' values used by a given inference method. The returned list uses
#' cmdstanr-style argument names (e.g., `iter_sampling` instead of
#' CmdStan's `num_samples`).
#'
#' The model must be compiled before calling this method.
#'
#' @param method (string) The inference method whose defaults to
#' retrieve. One of `"sample"`, `"optimize"`, `"variational"`,
#' `"pathfinder"`, or `"laplace"`.
#' @return A named list of default argument values for the specified
#' method, with cmdstanr-style argument names.
#'
#' @template seealso-docs
#'
#' @examples
#' \dontrun{
#' mod <- cmdstan_model(file.path(cmdstan_path(),
#' "examples/bernoulli/bernoulli.stan"))
#' mod$cmdstan_defaults("sample")
#' mod$cmdstan_defaults("optimize")
#' }
#'
cmdstan_defaults <- function(method = c("sample", "optimize", "variational",
"pathfinder", "laplace")) {
method <- match.arg(method)
if (length(self$exe_file()) == 0 || !file.exists(self$exe_file())) {
stop(
"'$cmdstan_defaults()' requires a compiled model. ",
"Please compile the model first with '$compile()'.",
call. = FALSE
)
}
parse_cmdstan_args(self$exe_file(), method)
}
CmdStanModel$set("public", name = "cmdstan_defaults", value = cmdstan_defaults)



# internal ----------------------------------------------------------------
assert_valid_stanc_options <- function(stanc_options) {
Expand Down Expand Up @@ -2289,10 +2335,10 @@ model_variables <- function(stan_file, include_paths = NULL, allow_undefined = F
variables
}


is_variables_method_supported <- function(mod) {
mod$has_stan_file() && file.exists(mod$stan_file())
}

resolve_exe_path <- function(dir = NULL,
private_dir = NULL,
self_exe_file = NULL,
Expand Down Expand Up @@ -2329,3 +2375,199 @@ resolve_exe_path <- function(dir = NULL,
}
exe
}

# cmdstan_defaults() helpers

#' Parse CmdStan default argument values from model binary
#'
#' Runs a CmdStan model binary with `help-all` to extract valid arguments
#' and their default values for a given inference method, returning them
#' with cmdstanr argument names.
#'
#' @noRd
#' @param model_binary Path to the CmdStan model binary.
#' @param method Inference method: `"sample"`, `"optimize"`,
#' `"variational"`, `"pathfinder"`, or `"laplace"`.
#' @return A named list with cmdstanr-style argument names and default
#' values.
parse_cmdstan_args <- function(model_binary, method) {
withr::with_path(
c(
toolchain_PATH_env_var(),
tbb_path()
),
ret <- wsl_compatible_run(
command = wsl_safe_path(model_binary),
args = c(method, "help-all"),
error_on_status = FALSE
)
)
# CmdStan may write help text to stdout or stderr depending on the platform
raw <- paste0(ret$stdout, ret$stderr)
output <- strsplit(raw, "\r?\n")[[1]]

argument_map <- map_cmdstan_to_cmdstanr(method)
cmdstan_keys <- unname(argument_map)
public_names <- names(argument_map)

defaults <- list()
n <- length(output)
# Track the current hierarchical argument key using section indentation.
section_indents <- integer(0)
section_names <- character(0)

for (i in seq_len(n)) {
line <- output[i]
content <- trimws(line)

# Skip blank lines so they don't reset the section stack
if (!nzchar(content)) next

indent <- nchar(sub("^(\\s*).*", "\\1", line))

# Drop sections at deeper or equal indentation
while (length(section_indents) > 0 &&
section_indents[[length(section_indents)]] >= indent) {
section_indents <- section_indents[-length(section_indents)]
section_names <- section_names[-length(section_names)]
}

section_name <- parse_cmdstan_section_name(content)
if (!is.null(section_name)) {
section_indents <- c(section_indents, indent)
section_names <- c(section_names, section_name)
next
}

arg_name <- parse_cmdstan_arg_name(content)
if (!is.null(arg_name)) {

# Build the full dotted argument key: method.section1.section2...arg_name
# The top-level method heading (e.g. "sample") is tracked as a section,
# so it becomes the first segment of the key.
full_key <- paste(c(section_names, arg_name), collapse = ".")

# Check if this full argument key matches one of our target arguments
match_idx <- match(full_key, cmdstan_keys, nomatch = 0L)

if (match_idx > 0L) {
default_value <- find_cmdstan_default_value(output, i, n)
defaults[[public_names[[match_idx]]]] <- default_value
}
}
}

defaults
}

#' Parse CmdStan section name from a help-all line
#' @noRd
parse_cmdstan_section_name <- function(line) {
match <- regmatches(line, regexec("^([a-z_][a-z0-9_]*)$", line))[[1]]
if (length(match) >= 2) match[2] else NULL
}

#' Parse CmdStan argument name from a help-all line
#' @noRd
parse_cmdstan_arg_name <- function(line) {
match <- regmatches(line, regexec("^([a-z_][a-z0-9_]*)=", line))[[1]]
if (length(match) >= 2) match[2] else NULL
}

#' Find CmdStan default value following a help-all argument line
#' @noRd
find_cmdstan_default_value <- function(output, line_idx, n_lines) {
default_value <- NULL

for (j in (line_idx + 1):min(line_idx + 5, n_lines)) {
next_content <- trimws(output[j])
if (grepl("^Defaults to", next_content)) {
default_value <- parse_default_value(next_content)
break
}
# Stop if we hit another argument
if (grepl("^[a-z_][a-z0-9_]*=", next_content)) break
}

default_value
}

#' Parse default value from "Defaults to ..." line
#' @noRd
parse_default_value <- function(line) {
val_str <- sub("^Defaults to\\s*", "", line)
if (val_str %in% c("true", "false")) return(val_str == "true")
if (grepl("^-?[0-9]+$", val_str)) return(as.integer(val_str))
if (grepl("^-?[0-9]*\\.?[0-9]+([eE][+-]?[0-9]+)?$", val_str)) return(as.numeric(val_str))
val_str
}

#' Map CmdStan argument names to CmdStanR argument names
#' @noRd
map_cmdstan_to_cmdstanr <- function(method) {
switch(method,
sample = c(
iter_sampling = "sample.num_samples",
iter_warmup = "sample.num_warmup",
save_warmup = "sample.save_warmup",
thin = "sample.thin",
adapt_engaged = "sample.adapt.engaged",
adapt_delta = "sample.adapt.delta",
init_buffer = "sample.adapt.init_buffer",
term_buffer = "sample.adapt.term_buffer",
window = "sample.adapt.window",
save_metric = "sample.adapt.save_metric",
max_treedepth = "sample.hmc.nuts.max_depth",
metric = "sample.hmc.metric",
metric_file = "sample.hmc.metric_file",
step_size = "sample.hmc.stepsize"
),
optimize = c(
algorithm = "optimize.algorithm",
jacobian = "optimize.jacobian",
iter = "optimize.iter",
init_alpha = "optimize.lbfgs.init_alpha",
tol_obj = "optimize.lbfgs.tol_obj",
tol_rel_obj = "optimize.lbfgs.tol_rel_obj",
tol_grad = "optimize.lbfgs.tol_grad",
tol_rel_grad = "optimize.lbfgs.tol_rel_grad",
tol_param = "optimize.lbfgs.tol_param",
history_size = "optimize.lbfgs.history_size"
),
variational = c(
algorithm = "variational.algorithm",
iter = "variational.iter",
grad_samples = "variational.grad_samples",
elbo_samples = "variational.elbo_samples",
eta = "variational.eta",
adapt_engaged = "variational.adapt.engaged",
adapt_iter = "variational.adapt.iter",
tol_rel_obj = "variational.tol_rel_obj",
eval_elbo = "variational.eval_elbo",
draws = "variational.output_samples"
),
pathfinder = c(
init_alpha = "pathfinder.init_alpha",
tol_obj = "pathfinder.tol_obj",
tol_rel_obj = "pathfinder.tol_rel_obj",
tol_grad = "pathfinder.tol_grad",
tol_rel_grad = "pathfinder.tol_rel_grad",
tol_param = "pathfinder.tol_param",
history_size = "pathfinder.history_size",
draws = "pathfinder.num_psis_draws",
num_paths = "pathfinder.num_paths",
save_single_paths = "pathfinder.save_single_paths",
psis_resample = "pathfinder.psis_resample",
calculate_lp = "pathfinder.calculate_lp",
max_lbfgs_iters = "pathfinder.max_lbfgs_iters",
single_path_draws = "pathfinder.num_draws",
num_elbo_draws = "pathfinder.num_elbo_draws"
),
laplace = c(
jacobian = "laplace.jacobian",
draws = "laplace.draws"
),
character(0)
)
}

1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ reference:
- read_cmdstan_csv
- write_stan_json
- write_stan_file
- print_stan_file
- draws_to_csv
- as_mcmc.list
- as_draws.CmdStanMCMC
Expand Down
1 change: 1 addition & 0 deletions man/CmdStanModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 10 additions & 10 deletions man/cmdstanr-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-check_syntax.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 65 additions & 0 deletions man/model-method-cmdstan_defaults.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading