From 097d4dc910906478dc36ff909cf89e3e2fd9f7e9 Mon Sep 17 00:00:00 2001 From: amas Date: Sun, 9 Nov 2025 11:06:47 -0500 Subject: [PATCH 1/5] Add save_cmdstan_config=1 to cmdstan call --- cmdstanpy/cmdstan_args.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cmdstanpy/cmdstan_args.py b/cmdstanpy/cmdstan_args.py index 671db7a0..997df668 100644 --- a/cmdstanpy/cmdstan_args.py +++ b/cmdstanpy/cmdstan_args.py @@ -866,6 +866,7 @@ def compose_command( cmd.append(f'init={self.inits[idx]}') cmd.append('output') cmd.append(f'file={csv_file}') + cmd.append('save_cmdstan_config=1') if diagnostic_file: cmd.append(f'diagnostic_file={diagnostic_file}') if profile_file: From af05639803dc9cc6699bdbb83f419f046501e2e8 Mon Sep 17 00:00:00 2001 From: amas Date: Sun, 9 Nov 2025 11:13:30 -0500 Subject: [PATCH 2/5] Add config files to RunSet --- cmdstanpy/stanfit/runset.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/cmdstanpy/stanfit/runset.py b/cmdstanpy/stanfit/runset.py index f88f55f3..ccbe4e66 100644 --- a/cmdstanpy/stanfit/runset.py +++ b/cmdstanpy/stanfit/runset.py @@ -56,6 +56,7 @@ def __init__( ) self._stdout_files, self._profile_files = [], [] self._csv_files, self._diagnostic_files = [], [] + self._config_files = [] # per-process output files if one_process_per_chain and chains > 1: @@ -63,6 +64,13 @@ def __init__( self.gen_file_name(".txt", extra="stdout", id=id) for id in self._chain_ids ] + self._config_files = [ + os.path.join( + self._outdir, f"{self._base_outfile}_{id}_config.json" + ) + for id in self._chain_ids + ] + if args.save_profile: self._profile_files = [ self.gen_file_name(".csv", extra="profile", id=id) @@ -70,6 +78,7 @@ def __init__( ] else: self._stdout_files = [self.gen_file_name(".txt", extra="stdout")] + self._config_files = [self.gen_file_name(".json", extra="config")] if args.save_profile: self._profile_files = [ self.gen_file_name(".csv", extra="profile") @@ -196,6 +205,13 @@ def stdout_files(self) -> list[str]: """ return self._stdout_files + @property + def config_files(self) -> list[str]: + """ + List of paths to CmdStan config json files. + """ + return self._config_files + def _check_retcodes(self) -> bool: """Returns ``True`` when all chains have retcode 0.""" return all(retcode == 0 for retcode in self._retcodes) From ea73385f6c19b7b24d3d182fe7f7aaa90fb6e501 Mon Sep 17 00:00:00 2001 From: amas Date: Sun, 9 Nov 2025 11:24:09 -0500 Subject: [PATCH 3/5] Clean up RunSet.__repr__, add config file --- cmdstanpy/stanfit/runset.py | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/cmdstanpy/stanfit/runset.py b/cmdstanpy/stanfit/runset.py index ccbe4e66..3764b8c2 100644 --- a/cmdstanpy/stanfit/runset.py +++ b/cmdstanpy/stanfit/runset.py @@ -102,25 +102,21 @@ def __init__( ] def __repr__(self) -> str: - repr = 'RunSet: chains={}, chain_ids={}, num_processes={}'.format( - self._chains, self._chain_ids, self._num_procs - ) - repr = '{}\n cmd (chain 1):\n\t{}'.format(repr, self.cmd(0)) - repr = '{}\n retcodes={}'.format(repr, self._retcodes) - repr = f'{repr}\n per-chain output files (showing chain 1 only):' - repr = '{}\n csv_file:\n\t{}'.format(repr, self._csv_files[0]) + lines = [ + f"RunSet: chains={self._chains}, chain_ids={self._chain_ids}, " + f"num_processes={self._num_procs}", + f" cmd (chain 1):\n\t{self.cmd(0)}", + f" retcodes={self._retcodes}", + " per-chain output files (showing chain 1 only):", + f" csv_file:\n\t{self._csv_files[0] if self._csv_files else ''}", + ] if self._args.save_latent_dynamics: - repr = '{}\n diagnostics_file:\n\t{}'.format( - repr, self._diagnostic_files[0] - ) + lines.append(f" diagnostics_file:\n\t{self._diagnostic_files[0]}") if self._args.save_profile: - repr = '{}\n profile_file:\n\t{}'.format( - repr, self._profile_files[0] - ) - repr = '{}\n console_msgs (if any):\n\t{}'.format( - repr, self._stdout_files[0] - ) - return repr + lines.append(f" profile_file:\n\t{self._profile_files[0]}") + lines.append(f" console_msgs (if any):\n\t{self._stdout_files[0]}") + lines.append(f" config_files:\n\t{self._config_files[0]}") + return '\n'.join(lines) @property def model(self) -> str: From 10361f74ec4107fc1f2c95ddaab8e5659d181ed9 Mon Sep 17 00:00:00 2001 From: amas Date: Sun, 9 Nov 2025 11:30:19 -0500 Subject: [PATCH 4/5] Add runset config filename tests --- test/test_runset.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/test_runset.py b/test/test_runset.py index 4576d924..616b3519 100644 --- a/test/test_runset.py +++ b/test/test_runset.py @@ -30,6 +30,7 @@ def test_check_repr() -> None: assert 'csv_file' in repr(runset) assert 'console_msgs' in repr(runset) assert 'diagnostics_file' not in repr(runset) + assert 'config_file' in repr(runset) def test_check_retcodes() -> None: @@ -106,6 +107,11 @@ def test_output_filenames_one_proc_per_chain() -> None: stdout_file.endswith(f"_stdout_{id}.txt") for id, stdout_file in zip(chain_ids, runset.stdout_files) ) + assert len(runset.config_files) == len(chain_ids) + assert all( + config_file.endswith(f"_{id}_config.json") + for id, config_file in zip(chain_ids, runset.config_files) + ) cmdstan_args_other_files = CmdStanArgs( model_name='bernoulli', @@ -153,6 +159,8 @@ def test_output_filenames_threading() -> None: ) assert len(runset.stdout_files) == 1 assert runset.stdout_files[0].endswith("_stdout.txt") + assert len(runset.config_files) == 1 + assert runset.config_files[0].endswith("_config.json") cmdstan_args_other_files = CmdStanArgs( model_name='bernoulli', @@ -198,6 +206,7 @@ def test_output_filenames_single_chain() -> None: runset = RunSet(args=cmdstan_args, chains=1, one_process_per_chain=True) base_file = runset._base_outfile assert runset.stdout_files[0].endswith(f"{base_file}_stdout.txt") + assert runset.config_files[0].endswith(f"{base_file}_config.json") cmdstan_args_other_files = CmdStanArgs( model_name='bernoulli', From 5ac4dc3b3cc0f3480426098b8dbfc61d866cd897 Mon Sep 17 00:00:00 2001 From: amas Date: Sun, 9 Nov 2025 11:39:22 -0500 Subject: [PATCH 5/5] Add config output tests --- test/test_cmdstan_args.py | 12 ++++++++++++ test/test_sample.py | 24 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/test/test_cmdstan_args.py b/test/test_cmdstan_args.py index b275fb9f..f25c9be4 100644 --- a/test/test_cmdstan_args.py +++ b/test/test_cmdstan_args.py @@ -808,3 +808,15 @@ def test_args_pathfinder_bad(arg: str, require_int: bool) -> None: args = PathfinderArgs(**{arg: 1.1}) # type: ignore with pytest.raises(ValueError): args.validate() + + +def test_save_cmdstan_config() -> None: + sampler_args = SamplerArgs() + cmdstan_args = CmdStanArgs( + model_name='bernoulli', + model_exe='', + chain_ids=[1, 2, 3, 4], + method_args=sampler_args, + ) + command = cmdstan_args.compose_command(0, csv_file="foo") + assert "save_cmdstan_config=1" in command diff --git a/test/test_sample.py b/test/test_sample.py index aacee74b..92d8fefa 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -2204,3 +2204,27 @@ def test_no_output_draws() -> None: mcmc = model.sample(data=data, iter_sampling=0, save_warmup=False, chains=2) draws = mcmc.draws() assert np.array_equal(draws, np.empty((0, 2, len(mcmc.column_names)))) + + +def test_config_output() -> None: + stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan') + jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') + model = CmdStanModel(stan_file=stan) + fit = model.sample( + data=jdata, + chains=2, + seed=12345, + iter_warmup=100, + iter_sampling=200, + ) + assert all(os.path.exists(cf) for cf in fit.runset.config_files) + + # Config file naming differs when only a single chain is output + fit_one_chain = model.sample( + data=jdata, + chains=1, + seed=12345, + iter_warmup=100, + iter_sampling=200, + ) + assert all(os.path.exists(cf) for cf in fit_one_chain.runset.config_files)