From 39ac270b21dbff97821ea20ea26b065f2e4d02e3 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 22 May 2026 16:32:26 +0200 Subject: [PATCH 01/21] Adapt aca-model to the pylcm #361 API restructure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adopt pylcm's public `lcm/` / private `_lcm/` package split and the accompanying API reorganisation: the `Regime` two-class split, the grid renames (`Piece` → `PiecewiseGridSegment`, `*Process` classes), the `FlatParams` rename, and the `regime` → `regime_id` / `regime_name` distinction. Declare `distributed=True` on the assets grid for multi-GPU sharding, activate the beartype claw on the package, and apply the boilerplate update (hatch-vcs versioning, refreshed pre-commit hooks, expanded `.gitignore`). Co-Authored-By: Claude Opus 4.7 --- .github/workflows/main.yml | 8 ++- .gitignore | 57 ++++++++++++++++-- .pre-commit-config.yaml | 16 ++++- pyproject.toml | 29 +++++++-- src/aca_model/__init__.py | 19 ++++++ .../_benchmark_data/benchmark_params.pkl | Bin 68562 -> 68025 bytes src/aca_model/_version.py | 1 - src/aca_model/agent/preferences.py | 5 +- src/aca_model/baseline/regimes/_common.py | 47 +++++++++------ src/aca_model/benchmark.py | 22 +++++-- src/aca_model/consumption_dollars_grid.py | 29 ++++----- src/aca_model/environment/social_security.py | 6 +- src/aca_model/environment/taxes.py | 4 +- tests/test_aca_policies.py | 32 +++++----- tests/test_beartype_claw.py | 25 ++++++++ tests/test_benchmark.py | 4 +- tests/test_budget_chain_integration.py | 8 +-- tests/test_health_insurance.py | 16 ++--- .../test_initial_conditions_extreme_assets.py | 10 +-- tests/test_model_components.py | 18 +++--- tests/test_model_creation.py | 12 ++-- tests/test_regime_transitions.py | 44 +++++++------- tests/test_social_security.py | 6 +- tests/test_ss_benefit_integration.py | 4 +- tests/test_ssi_medicaid_integration.py | 48 +++++++-------- tests/test_taxes.py | 16 ++--- 26 files changed, 317 insertions(+), 169 deletions(-) delete mode 100644 src/aca_model/_version.py create mode 100644 tests/test_beartype_claw.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8104668..e57ac0b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,5 +1,9 @@ --- name: main +# aca-model is a git submodule of the aca-dev workspace and has no pixi config +# of its own — the pixi environments live in the parent workspace, whose +# `tests-cpu` env has editable path-dependencies on private sibling repos that a +# standalone CI runner cannot clone. CI therefore installs with pip directly. concurrency: group: ${{ github.head_ref || github.run_id }} cancel-in-progress: true @@ -26,10 +30,10 @@ jobs: - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - - name: Install pylcm (feature branch — revert to @main once pylcm#348/#350 merge) + - name: Install pylcm (pinned to the phase-2 branch until it merges to main) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@feat/runtime-grid-extra-params" + git+https://github.com/OpenSourceEconomics/pylcm.git@refactor/phase-2-api-reorganisation" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest diff --git a/.gitignore b/.gitignore index e2bab43..70259d4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,59 @@ -__pycache__/ -*.py[cod] +# Claude Code +.claude/ + +# Distribution / packaging +*.egg *.egg-info/ -dist/ +*.manifest +*.spec +.eggs/ +.installed.cfg build/ -bld/ +dist/ +MANIFEST +sdist/ +wheels/ + +# IDE +.idea/ +.vscode/ + +# Jupyter / Jupyter Book +.ipynb_checkpoints/ +_build + +# macOS +.DS_Store + +# pixi .pixi/ +node_modules/ + +# pytask .pytask/ +.pytask.sqlite3 +bld/ +out/ +pytask.lock +pytask.lock.journal + +# Python +__pycache__/ +*.py[cod] +*.so +*$py.class + +# Ruff +.ruff_cache/ + +# Testing +.cache/ .coverage +.coverage.* +.hypothesis/ +.pytest_cache/ +coverage.xml htmlcov/ + +# Version file (generated by hatch-vcs) +src/*/_version.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f3188ab..50b0b1a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,9 @@ repos: - repo: meta hooks: - - id: check-hooks-apply + # check-hooks-apply is omitted: aca-model ships no notebooks yet, so the + # boilerplate nbstripout hook matches nothing and that meta check would + # fail. Re-add it once the repo gains a notebook. - id: check-useless-excludes - repo: https://github.com/tox-dev/pyproject-fmt rev: v2.21.1 @@ -37,6 +39,7 @@ repos: - id: name-tests-test args: - --pytest-test-first + exclude: ^tests/helpers/ - id: no-commit-to-branch args: - --branch @@ -46,6 +49,10 @@ repos: rev: v1.38.0 hooks: - id: yamllint + - repo: https://github.com/python-jsonschema/check-jsonschema + rev: 0.37.2 + hooks: + - id: check-github-workflows - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.15.12 hooks: @@ -61,6 +68,13 @@ repos: - jupyter - pyi - python + - repo: https://github.com/kynan/nbstripout + rev: 0.9.1 + hooks: + - id: nbstripout + args: + - --extra-keys + - metadata.kernelspec metadata.language_info.version metadata.vscode - repo: https://github.com/executablebooks/mdformat rev: 1.0.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index 7bfbd0b..7ca2558 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,9 @@ [build-system] build-backend = "hatchling.build" -requires = [ "hatchling" ] +requires = [ "hatch-vcs", "hatchling" ] [project] name = "aca-model" -version = "0.0.0" description = "Core lifecycle model for the ACA structural retirement project." readme = { file = "README.md", content-type = "text/markdown" } keywords = [ @@ -23,8 +22,10 @@ classifiers = [ "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.14", ] +dynamic = [ "version" ] dependencies = [ "attrs", + "beartype", "cloudpickle", "dags", "estimagic", @@ -43,13 +44,19 @@ email = "hmgaudecker@uni-bonn.de" [[project.maintainers]] name = "Hans-Martin von Gaudecker" email = "hmgaudecker@uni-bonn.de" +[project.urls] +Github = "https://github.com/OpenSourceEconomics/aca-model" +Repository = "https://github.com/OpenSourceEconomics/aca-model" +Tracker = "https://github.com/OpenSourceEconomics/aca-model/issues" [tool.hatch] +build.hooks.vcs.version-file = "src/aca_model/_version.py" build.targets.sdist.exclude = [ "tests" ] build.targets.sdist.only-packages = true build.targets.wheel.only-include = [ "src" ] build.targets.wheel.sources = [ "src" ] metadata.allow-direct-references = true +version.source = "vcs" [tool.ruff] fix = true @@ -84,9 +91,21 @@ extend-ignore = [ "RUF002", # Ambiguous Unicode in docstrings (Greek letters in math) "RUF003", # Ambiguous Unicode in comments (Greek letters in math) ] -per-file-ignores."src/aca_model/models/*" = [ "E501" ] -per-file-ignores."task_*.py" = [ "ANN", "ARG001" ] -per-file-ignores."tests/*" = [ "D", "E501", "INP001", "PD011", "PLR2004", "S101" ] +per-file-ignores."src/aca_model/models/*" = [ + "E501", # Line too long (generated model files) +] +per-file-ignores."task_*.py" = [ + "ANN", # Type annotations (use ty instead) + "ARG001", # Unused function argument (pytask signatures) +] +per-file-ignores."tests/*" = [ + "D", # Docstrings + "E501", # Line too long + "INP001", # Implicit namespace package + "PD011", # Use of .values (false positives on non-pandas objects) + "PLR2004", # Magic value used in comparison + "S101", # Use of assert +] pydocstyle.convention = "google" [tool.pyproject-fmt] diff --git a/src/aca_model/__init__.py b/src/aca_model/__init__.py index ac63b78..1fc43f7 100644 --- a/src/aca_model/__init__.py +++ b/src/aca_model/__init__.py @@ -1,3 +1,22 @@ import jax jax.config.update("jax_enable_x64", True) + +# Import lcm before installing the claw so its `_jaxtyping_patch` (picklable +# jaxtyping sentinel) and `MappingProxyType` pytree registration are in place. +import lcm # noqa: E402, F401 + +# Install beartype's AST-rewriting claw on the whole `aca_model` package before +# any submodule is imported. The claw transforms each module's AST at first +# import to insert runtime type checks against its annotations; aca_model's +# numerical DAG/transition/utility functions are otherwise unchecked, since +# pylcm's own claw is scoped to `lcm.*`. Violations surface as beartype's +# `BeartypeCallHintViolation` — aca_model is an application, not a library with +# a documented exception contract. +from beartype import BeartypeConf, BeartypeStrategy # noqa: E402 +from beartype.claw import beartype_package # noqa: E402 + +beartype_package( + "aca_model", + conf=BeartypeConf(strategy=BeartypeStrategy.On, is_pep484_tower=True), +) diff --git a/src/aca_model/_benchmark_data/benchmark_params.pkl b/src/aca_model/_benchmark_data/benchmark_params.pkl index f7c505ef8fb94731af080e93df12a095b942dc27..15053c68abd1d7813b2188faf491cb90243bcdce 100644 GIT binary patch delta 1911 zcmbtSdrVtZ7(ZWoOIzB}!lr#e?S`2Z8yrp%$S5QYz1Ltevq=^ouyt|l?k?L5P>JCY zV%Vd`xztm<3`_--unwzDBjl z@5UoaR;#_uoE?{VG-92qELSgW5Hh(@dn1dwXIRXu*2c1HiQ3JD$KBbpz>+UiPwHjH z&3M*$#Ixn-Bp*V(RyX=ubfa?>ESkQ+qJ9_G_3if@3jV;w@F2wvEIRiYi;C=A*JUqP z{Z)vozA={($1j*GJX5vh}wHC>RE-oCYTy#}9S$VH>P0gQ{Y*qjm}+ClJ0T5z_Cm#ZKNGVHe$va&!|p+P|5put;5e*%i6}gR36lP~ zl#>MU*v-eGjf8I+z?x_wC^e&0BgE3lMVHVPok;P#9IT1X{*8BJ{X%lDLHzQqd|jFr z*?SAXk*1TA;+I}AdK+X=hQo^mSr$qZO30+aim9={r?gq^rA;s6);-L=x+n+KYOo%ofdQA{Fvxc`_5|2gtvPY-OsP0B)9tqgabD1>2g0l@U^$F$H_ zAFvB&=pSK}ezvs)wo;LLU-SOJR@&;PH0Rw%-?E2hDPree69&ZwfcN4W<@cjeAK=-Z zqf}B_VucJWO5sboO*vKJJt&IjWpFrtf_vg0l7syLy~ZpIybiEed`AIxcwaoLfVzx# f7k@otrAoN*Ko>3^)W8bti9953V!$T%dW-)7*n(Oi delta 2526 zcmbtTYfKzf6rQ^;mM#Gnwq+me0tFjrDZ331DV5Ib5(|m7B0h*s5kaZTqzjaX#8R@< zTIp()(mQlRe;Bh(8&U|hj7lpSpd<>hVg)OfH-=UtpdhWF1?!!eU6v)(@VLp$nS0K6 zzWbeX?p%X{E~`LLCamIJ5lo2qPY3RU|A^j#kpTfJOM*P^UVHbHDVvbDdGz{EOmz43x!->lP7E2iMGTI^Ga z>i?ej_6AxBqXZKgO`awq(Gs=*eHg0uD_p8S9KWHb{23{GIUMPZsYLaU*z@lZNh3#O ze}01LQ?6=Glt;Vi>k~(g^qGA0W6xgIP%j%~>m-orA*ISt$dLr+)wJs=V}3K9e>fXm z(8zf0Xan4}II0~85gn_dafRST&+QHG_7RX_xr%XVUv=Z=|e>yC#57=Qz(A8vc zd^g{=^OkaUxOTZcOjhf5=-9})VFFo3Dt8T!-g)F6R;zIxmxT@EO#dSz$L!bUk$ZYl z9bPLkMxEkiN5RTq4IR{Ba^r0g!`7X0|7~InGZJJlAH#~?p;9R0CZoKR3NliEivF1tBuzv?K@s5>FMPknysmpb$}zDik37nU@)MPRQxJ zAH)l2-v6Sfdi-em79V&kne3%JgFm83<&p?J)uRNGMR4eeRUICCS%A>|x{=Y$GqoYgIdIq&udh;g=v zft>OhQWHend9m59! zu=0UM9D+Qkfzei}0OTW7`b>@s7EaB`yi^Qmp?}G6Xif=FOE&Y!T$nN&4iahe>?jZH zwE{5+Uk&dYF1#g(SEVp9zI-x=F-7PS7o+FRlLhD~-}wgjsxKs5qv1%fa386lxs<&Ldj3j2l*L#WffAHh$qcbP)=o{az95t_V>y_UK9y6 zzD8bv+J|TC6kAqC0rUQqZQLw7+Mas(4(^fbx9{Wa<2Yq>l@q)IuBx-3di6(Jj{_hT ztl_xE2LR*&V6h(JgHUkJTEz!1taNS_)AGg$A?RAeysaj8AQO^jQS+@4JP^h&DvJ0I Du8P~M diff --git a/src/aca_model/_version.py b/src/aca_model/_version.py deleted file mode 100644 index 6c8e6b9..0000000 --- a/src/aca_model/_version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.0.0" diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 612896b..97fff2c 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -140,9 +140,10 @@ def u_alive( coefficient_rra: FloatND, utility_scale_factor: FloatND, ) -> FloatND: - """Within-period utility for every non-dead regime: CES over consumption and leisure. + """Within-period utility for every non-dead regime. - `leisure` is a DAG input — supplied per-regime by `leisure_canwork_retiree_or_nongroup`, + CES over consumption and leisure. `leisure` is a DAG input — supplied + per-regime by `leisure_canwork_retiree_or_nongroup`, `leisure_canwork_tied`, or `leisure_forcedout`. """ composite = consumption_equiv**consumption_weight * leisure ** ( diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index b4d1e26..b783ecb 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -10,20 +10,21 @@ from typing import Any, Literal, TypedDict import jax.numpy as jnp -import lcm.shocks.ar1 -import lcm.shocks.iid import numpy as np +from _lcm.grids.continuous import ContinuousGrid from lcm import ( DiscreteGrid, IrregSpacedGrid, LinSpacedGrid, MarkovTransition, + NormalIIDProcess, + PiecewiseGridSegment, + PiecewiseLinSpacedGrid, Regime, + RouwenhorstAR1Process, categorical, ) -from lcm.grids.continuous import ContinuousGrid -from lcm.grids.piecewise import Piece, PiecewiseLinSpacedGrid -from lcm.typing import BoolND, FloatND, RegimeName, ScalarInt, UserParams +from lcm.typing import BoolND, FloatND, IntND, RegimeName, ScalarInt, UserParams from aca_model.agent import ( assets_and_income, @@ -34,7 +35,7 @@ from aca_model.agent.health import Health, HealthWithDisability from aca_model.agent.labor_market import LaborSupply, LaggedLaborSupply, SpousalIncome from aca_model.baseline import health_insurance -from aca_model.baseline.health_insurance import BuyPrivate, HealthInsuranceState +from aca_model.baseline.health_insurance import BuyPrivate from aca_model.config import MODEL_CONFIG, GridConfig from aca_model.environment import social_security, taxes from aca_model.environment.social_security import ClaimedSS @@ -237,14 +238,14 @@ def build_grids( # grid to have unconditional variance 1, the Rouwenhorst innovation # std must be √(1 − ρ²). Passing the σ_y itself (≈0.577 for hcc, # 0.5627 for wage) would mis-scale the grid. - wage_res = lcm.shocks.ar1.Rouwenhorst( + wage_res = RouwenhorstAR1Process( n_points=grid_config.n_wage_res_gridpoints, rho=_WAGE_RHO, sigma=(1.0 - _WAGE_RHO**2) ** 0.5, mu=0.0, ) hcc_persistent = get_hcc_persistent_shock(grid_config=grid_config) - hcc_transitory = lcm.shocks.iid.Normal( + hcc_transitory = NormalIIDProcess( n_points=grid_config.n_hcc_transitory_gridpoints, gauss_hermite=True, mu=0.0, @@ -261,11 +262,11 @@ def build_grids( stop=500_000.0, n_points=grid_config.n_assets_gridpoints, batch_size=grid_config.n_assets_batch_size, + distributed=True, ), aime=_build_aime_grid(grid_config=grid_config, fixed_params=fixed_params), consumption_dollars=IrregSpacedGrid( n_points=grid_config.n_consumption_dollars_gridpoints, - extra_param_names=("max_consumption_dollars",), ), wage_res=wage_res, hcc_persistent=hcc_persistent, @@ -274,7 +275,7 @@ def build_grids( ) -def get_hcc_persistent_shock(*, grid_config: GridConfig) -> lcm.shocks.ar1.Rouwenhorst: +def get_hcc_persistent_shock(*, grid_config: GridConfig) -> RouwenhorstAR1Process: """Return the persistent-HCC AR(1) shock grid for a given `grid_config`. Exposed so callers that need the shock's gridpoints / transition @@ -282,7 +283,7 @@ def get_hcc_persistent_shock(*, grid_config: GridConfig) -> lcm.shocks.ar1.Rouwe can derive them from `grid_config` alone without instantiating a full `Model`. """ - return lcm.shocks.ar1.Rouwenhorst( + return RouwenhorstAR1Process( n_points=grid_config.n_hcc_persistent_gridpoints, rho=_HCC_RHO, sigma=(1.0 - _HCC_RHO**2) ** 0.5, @@ -306,20 +307,26 @@ def _build_aime_grid( this path; the total is fixed by the PIA structure (32 points). """ kinks = [float(k) for k in np.asarray(fixed_params["pia_aime_grid"])] - pieces = ( - Piece(interval=f"[{kinks[0]}, {kinks[1]})", n_points=_AIME_PIECE_N_POINTS[0]), - Piece(interval=f"[{kinks[1]}, {kinks[2]})", n_points=_AIME_PIECE_N_POINTS[1]), - Piece(interval=f"[{kinks[2]}, {kinks[3]}]", n_points=_AIME_PIECE_N_POINTS[2]), + segments = ( + PiecewiseGridSegment( + interval=f"[{kinks[0]}, {kinks[1]})", n_points=_AIME_PIECE_N_POINTS[0] + ), + PiecewiseGridSegment( + interval=f"[{kinks[1]}, {kinks[2]})", n_points=_AIME_PIECE_N_POINTS[1] + ), + PiecewiseGridSegment( + interval=f"[{kinks[2]}, {kinks[3]}]", n_points=_AIME_PIECE_N_POINTS[2] + ), ) return PiecewiseLinSpacedGrid( - pieces=pieces, batch_size=grid_config.n_aime_batch_size + segments=segments, batch_size=grid_config.n_aime_batch_size ) def _compute_max_annual_labor_income( *, wage_params: Mapping[str, Any], - wage_res_grid: lcm.shocks.ar1.Rouwenhorst, + wage_res_grid: RouwenhorstAR1Process, ) -> float: """Return the annual labor income at the top of the wage grid. @@ -418,7 +425,7 @@ def build_actions(spec: RegimeSpec, grids: Grids) -> dict: return actions -def build_regime_probs(target: FloatND, survival: FloatND) -> FloatND: +def build_regime_probs(target: IntND, survival: FloatND) -> FloatND: """Build regime transition probability vector.""" probs = jnp.zeros(19) probs = probs.at[RegimeId.dead].set(1.0 - survival) @@ -603,10 +610,10 @@ def make_targets(name: str) -> tuple[dict[str, int], dict[str, int]]: def select_target_for_age( - next_age: int | FloatND, + next_age: int | IntND | FloatND, mc_next: bool | BoolND, tgts: dict[str, int], -) -> FloatND: +) -> IntND: """Select target regime ID based on next-period age bracket.""" ss_choose = jnp.where( jnp.array(mc_next), diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 5b519d5..a9ea000 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -100,19 +100,29 @@ def get_benchmark_params( ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: """Load the frozen `(fixed_params, wage_params, params)` snapshot. + `max_consumption_dollars` is popped out of `fixed_params` before + return — it's a grid-construction input read by + `inject_consumption_dollars_points`, not by any DAG function, so + leaving it in `fixed_params` would trip pylcm's unknown-keys check. + When `model` is provided, consumption_dollars gridpoints are injected into `params` for each regime that declares `consumption_dollars` as - an `IrregSpacedGrid` with runtime-supplied points. The lower bound is - read from `params["consumption_dollars_floor"]`. Pass `model=None` to - skip injection (e.g. when constructing the model with `fixed_params`). + an `IrregSpacedGrid` with runtime-supplied points. Pass `model=None` + to skip injection (e.g. when constructing the model with + `fixed_params`). """ with _PARAMS_FILE.open("rb") as fh: data = cloudpickle.load(fh) fixed_params = data["fixed_params"] wage_params = data["wage_params"] params = data["params"] + max_consumption_dollars = float(fixed_params.pop("max_consumption_dollars")) if model is not None: - params = inject_consumption_dollars_points(params=params, model=model) + params = inject_consumption_dollars_points( + params=params, + model=model, + max_consumption_dollars=max_consumption_dollars, + ) return fixed_params, wage_params, params @@ -133,7 +143,7 @@ def get_benchmark_initial_conditions( # Grid ranges come from any of the five regimes (shared structure). # Use to_jax() so the helper handles both LinSpacedGrid and # PiecewiseLinSpacedGrid (the latter has no `.start` / `.stop`). - ref_regime = model.regimes[_INITIAL_REGIMES[0]] + ref_regime = model.user_regimes[_INITIAL_REGIMES[0]] grids = ref_regime.states assets_pts = np.asarray(grids["assets"].to_jax()) aime_pts = np.asarray(grids["aime"].to_jax()) @@ -151,7 +161,7 @@ def get_benchmark_initial_conditions( ) return { - "regime": jnp.asarray(regime), + "regime_id": jnp.asarray(regime), "age": jnp.full(n_subjects, 51.0), "assets": jnp.asarray(rng.uniform(assets_lo, assets_hi, n_subjects)), "aime": jnp.asarray(rng.uniform(aime_lo, aime_hi, n_subjects)), diff --git a/src/aca_model/consumption_dollars_grid.py b/src/aca_model/consumption_dollars_grid.py index d99f8a9..362498c 100644 --- a/src/aca_model/consumption_dollars_grid.py +++ b/src/aca_model/consumption_dollars_grid.py @@ -1,14 +1,13 @@ """Runtime-supplied gridpoints for the consumption_dollars action. -Consumption is declared as `IrregSpacedGrid(n_points=N, -extra_param_names=("max_consumption_dollars",))` in +Consumption is declared as `IrregSpacedGrid(n_points=N)` in `baseline.regimes._common.build_grids` so the bounds can track runtime parameters: the lower bound from the per-iteration `consumption_equiv_floor` parameter (and its couples-scaled twin), -the upper bound from `max_consumption_dollars` carried through -`fixed_params` (per pylcm#348). Callers must inject the actual -gridpoints into `params` via `inject_consumption_dollars_points` -before calling `model.solve()` / `model.simulate()`. +the upper bound from `max_consumption_dollars` supplied directly +by the caller. Callers must inject the actual gridpoints into +`params` via `inject_consumption_dollars_points` before calling +`model.solve()` / `model.simulate()`. The grid pins the two regime-relevant transfer-floor levels exactly on the action grid so the borrowing constraint's @@ -32,6 +31,7 @@ def inject_consumption_dollars_points( *, params: Mapping[str, Any], model: Model, + max_consumption_dollars: float, ) -> dict[str, Any]: """Inject consumption_dollars gridpoints into per-regime params. @@ -40,7 +40,7 @@ def inject_consumption_dollars_points( The lower two gridpoints are the single and married Dollar-valued transfer floors; the rest are geomspaced from the married floor up - to `model.fixed_params["max_consumption_dollars"]`. + to `max_consumption_dollars`. Args: params: Existing params mapping with `consumption_equiv_floor` @@ -48,8 +48,11 @@ def inject_consumption_dollars_points( new dict; the input is not mutated. model: Model whose regimes carry the runtime-points grid and whose `fixed_params` supplies `exponent` (married - equivalence-scale exponent) and `max_consumption_dollars` - (grid upper bound). + equivalence-scale exponent). + max_consumption_dollars: Grid upper bound. Sourced from the + caller (e.g. aca-data's `environment_constants.pkl`); not + routed through pylcm's params machinery because no DAG + function consumes it. Returns: New params dict with consumption_dollars points injected. @@ -61,11 +64,9 @@ def inject_consumption_dollars_points( """ consumption_equiv_floor = jnp.asarray(params["consumption_equiv_floor"]) exponent = jnp.asarray(model.fixed_params["exponent"]) - max_consumption_dollars = jnp.asarray( - model.fixed_params["max_consumption_dollars"] - ) + max_consumption_dollars_arr = jnp.asarray(max_consumption_dollars) out: dict[str, Any] = dict(params) - for regime_name, regime in model.regimes.items(): + for regime_name, regime in model.user_regimes.items(): if regime.terminal: continue grid = regime.actions.get("consumption_dollars") @@ -88,7 +89,7 @@ def inject_consumption_dollars_points( points = _compute_consumption_dollars_points( consumption_equiv_floor=consumption_equiv_floor, exponent=exponent, - max_consumption_dollars=max_consumption_dollars, + max_consumption_dollars=max_consumption_dollars_arr, n_points=grid.n_points, ) regime_entry = dict(out.get(regime_name, {})) diff --git a/src/aca_model/environment/social_security.py b/src/aca_model/environment/social_security.py index 8b655d1..837ff5d 100644 --- a/src/aca_model/environment/social_security.py +++ b/src/aca_model/environment/social_security.py @@ -11,10 +11,12 @@ from lcm import categorical from lcm.typing import ( Age, + BoolND, ContinuousState, DiscreteAction, DiscreteState, FloatND, + IntND, Period, ScalarFloat, ScalarInt, @@ -211,8 +213,8 @@ def _apply_benefit_rules( pia: FloatND, age: Age, period: Period, - ss: FloatND, - work: FloatND, + ss: IntND, + work: BoolND, labor_income: FloatND, early_ret_adjustment: FloatND, normal_retirement_age: ScalarInt, diff --git a/src/aca_model/environment/taxes.py b/src/aca_model/environment/taxes.py index 8ed25e8..6d9e1d3 100644 --- a/src/aca_model/environment/taxes.py +++ b/src/aca_model/environment/taxes.py @@ -8,7 +8,7 @@ import jax.numpy as jnp from lcm.params import MappingLeaf -from lcm.typing import DiscreteState, FloatND +from lcm.typing import DiscreteState, FloatND, IntND def gross_income( @@ -136,7 +136,7 @@ def marginal_rate( return sched["marginal_rates"][spousal_income, bracket_id] -def _find_bracket(income: FloatND, upper_bounds: FloatND) -> FloatND: +def _find_bracket(income: FloatND, upper_bounds: FloatND) -> IntND: """Find the tax bracket index for a given income level.""" return jnp.searchsorted(upper_bounds, income, side="right") diff --git a/tests/test_aca_policies.py b/tests/test_aca_policies.py index 8ecc024..7ff22ae 100644 --- a/tests/test_aca_policies.py +++ b/tests/test_aca_policies.py @@ -63,7 +63,7 @@ def test_mandate_penalty_uninsured_above_exempt() -> None: income = jnp.array(40000.0) # 40000 * 0.025 = 1000, within [695, 2085] result = aca_hi.mandate_penalty( gross_income=income, - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.no), mandate_schedule=MANDATE_SCHEDULE, ) @@ -74,7 +74,7 @@ def test_mandate_penalty_insured_zero() -> None: """buy_private=yes produces no penalty.""" result = aca_hi.mandate_penalty( gross_income=jnp.array(40000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), mandate_schedule=MANDATE_SCHEDULE, ) @@ -85,7 +85,7 @@ def test_mandate_penalty_below_exempt_zero() -> None: """Income below exemption produces no penalty.""" result = aca_hi.mandate_penalty( gross_income=jnp.array(5000.0), # below 10350 - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.no), mandate_schedule=MANDATE_SCHEDULE, ) @@ -97,7 +97,7 @@ def test_mandate_penalty_clips_to_min() -> None: # 12000 * 0.025 = 300, below min of 695 result = aca_hi.mandate_penalty( gross_income=jnp.array(12000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.no), mandate_schedule=MANDATE_SCHEDULE, ) @@ -109,7 +109,7 @@ def test_mandate_penalty_clips_to_max() -> None: # 200000 * 0.025 = 5000, above max of 2085 result = aca_hi.mandate_penalty( gross_income=jnp.array(200000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.no), mandate_schedule=MANDATE_SCHEDULE, ) @@ -121,7 +121,7 @@ def test_hic_premium_subsidy_below_fpl_zero() -> None: result = aca_hi.premium_subsidy( hic_premium=jnp.array(5000.0), gross_income=jnp.array(10000.0), # below FPL_SINGLE - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), premium_credit_schedule=PREMIUM_CREDIT_SCHEDULE, ) @@ -133,7 +133,7 @@ def test_hic_premium_subsidy_above_400_fpl_zero() -> None: result = aca_hi.premium_subsidy( hic_premium=jnp.array(5000.0), gross_income=jnp.array(50000.0), # above 4 * FPL_SINGLE = 47080 - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), premium_credit_schedule=PREMIUM_CREDIT_SCHEDULE, ) @@ -148,7 +148,7 @@ def test_hic_premium_subsidy_at_200_fpl() -> None: result = aca_hi.premium_subsidy( hic_premium=jnp.array(premium), gross_income=jnp.array(income), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), premium_credit_schedule=PREMIUM_CREDIT_SCHEDULE, ) @@ -160,7 +160,7 @@ def test_hic_premium_subsidy_uninsured_zero() -> None: result = aca_hi.premium_subsidy( hic_premium=jnp.array(5000.0), gross_income=jnp.array(2.0 * FPL_SINGLE), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.no), premium_credit_schedule=PREMIUM_CREDIT_SCHEDULE, ) @@ -183,7 +183,7 @@ def test_cost_sharing_scale_brackets( """Verify each cost-sharing bracket produces the correct factor.""" result = aca_hi.cost_sharing( gross_income=jnp.array(income_fpl_frac * FPL_SINGLE), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), cost_sharing_schedule=COST_SHARING_SCHEDULE, ) @@ -194,7 +194,7 @@ def test_cost_sharing_scale_uninsured_one() -> None: """buy_private=no produces scale=1.0 (no reduction).""" result = aca_hi.cost_sharing( gross_income=jnp.array(1.2 * FPL_SINGLE), # would be 0.1721 if insured - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.no), cost_sharing_schedule=COST_SHARING_SCHEDULE, ) @@ -205,7 +205,7 @@ def test_medicaid_eligible_aca_below_threshold() -> None: """Income below 133% FPL produces eligible.""" result = aca_hi.is_medicaid_eligible( countable_income=jnp.array(10000.0), # below 15580 - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), medicaid_schedule=MEDICAID_SCHEDULE, ) assert result @@ -215,7 +215,7 @@ def test_medicaid_eligible_aca_above_threshold() -> None: """Income above 133% FPL produces not eligible.""" result = aca_hi.is_medicaid_eligible( countable_income=jnp.array(20000.0), # above 15580 - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), medicaid_schedule=MEDICAID_SCHEDULE, ) assert not result @@ -235,7 +235,7 @@ def test_premium_subsidy_exactly_at_100_fpl() -> None: result = aca_hi.premium_subsidy( hic_premium=jnp.array(5000.0), gross_income=jnp.array(1.0 * FPL_SINGLE), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), premium_credit_schedule=PREMIUM_CREDIT_SCHEDULE, ) @@ -247,7 +247,7 @@ def test_premium_subsidy_exactly_at_400_fpl() -> None: result = aca_hi.premium_subsidy( hic_premium=jnp.array(5000.0), gross_income=jnp.array(4.0 * FPL_SINGLE), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), premium_credit_schedule=PREMIUM_CREDIT_SCHEDULE, ) @@ -259,7 +259,7 @@ def test_premium_subsidy_just_below_400_fpl() -> None: result = aca_hi.premium_subsidy( hic_premium=jnp.array(5000.0), gross_income=jnp.array(4.0 * FPL_SINGLE - 1.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), premium_credit_schedule=PREMIUM_CREDIT_SCHEDULE, ) diff --git a/tests/test_beartype_claw.py b/tests/test_beartype_claw.py new file mode 100644 index 0000000..2fb7873 --- /dev/null +++ b/tests/test_beartype_claw.py @@ -0,0 +1,25 @@ +"""The beartype claw is live on the `aca_model` package. + +Registering `beartype_package("aca_model", ...)` in `aca_model/__init__.py` +instruments every `aca_model` module at import time, so a type violation in +any aca_model function — including the numerical DAG leaf functions fed into +pylcm — is caught at the call boundary rather than slipping through against +a dishonest annotation. + +The test calls a real model-builder with one argument of the wrong type; the +`BeartypeCallHintViolation` is what proves the claw is installed. +""" + +import pytest +from beartype.roar import BeartypeCallHintViolation +from helpers.model import make_baseline_model + + +def test_claw_checks_aca_model() -> None: + """An ill-typed argument to an `aca_model` function is rejected by beartype. + + `create_model` annotates `n_subjects` as `int`; passing a string is caught + by the claw before the value reaches pylcm's own `Model` perimeter. + """ + with pytest.raises(BeartypeCallHintViolation): + make_baseline_model(n_subjects="not an int") # ty: ignore[invalid-argument-type] diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index b1be815..61b6f8b 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -29,7 +29,6 @@ def test_benchmark_model_simulates_end_to_end() -> None: initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, log_level="off", - check_initial_conditions=False, ) df = result.to_dataframe() @@ -69,11 +68,10 @@ def test_benchmark_simulate_obeys_borrowing_constraint() -> None: initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, log_level="off", - check_initial_conditions=False, ) df = result.to_dataframe(additional_targets=["cash_on_hand", "equivalence_scale"]) - alive = df.loc[df["regime"] != "dead"].copy() + alive = df.loc[df["regime_name"] != "dead"].copy() consumption_dollars_floor = float(params["consumption_dollars_floor"]) floor = consumption_dollars_floor * alive["equivalence_scale"].to_numpy() rhs = np.maximum(alive["cash_on_hand"].to_numpy(), floor) diff --git a/tests/test_budget_chain_integration.py b/tests/test_budget_chain_integration.py index f087d16..03e5fbe 100644 --- a/tests/test_budget_chain_integration.py +++ b/tests/test_budget_chain_integration.py @@ -53,12 +53,12 @@ def test_working_agent_cash_on_hand() -> None: result = combined( assets=jnp.array(50000.0), - rate_of_return=0.03, + rate_of_return=jnp.asarray(0.03), labor_income=jnp.array(40000.0), spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), ss_benefit=jnp.array(0.0), pension_benefit=jnp.array(0.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), income_tax_schedule=INCOME_TAX_SCHEDULE, payroll_tax_schedule=PAYROLL_TAX_SCHEDULE, ss_tax_schedule=SS_TAX_SCHEDULE, @@ -87,12 +87,12 @@ def test_retired_agent_with_pension() -> None: result = combined( assets=jnp.array(200000.0), - rate_of_return=0.03, + rate_of_return=jnp.asarray(0.03), labor_income=jnp.array(0.0), spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), ss_benefit=jnp.array(15000.0), pension_benefit=jnp.array(10000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), income_tax_schedule=INCOME_TAX_SCHEDULE, payroll_tax_schedule=PAYROLL_TAX_SCHEDULE, ss_tax_schedule=SS_TAX_SCHEDULE, diff --git a/tests/test_health_insurance.py b/tests/test_health_insurance.py index 06a23b7..56dd7af 100644 --- a/tests/test_health_insurance.py +++ b/tests/test_health_insurance.py @@ -17,7 +17,7 @@ def test_ssi_eligible_assets_too_high() -> None: result = health_insurance.is_ssi_eligible( assets=jnp.array(5000.0), countable_income=jnp.array(1000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, @@ -29,7 +29,7 @@ def test_ssi_eligible_income_too_high() -> None: result = health_insurance.is_ssi_eligible( assets=jnp.array(1000.0), countable_income=jnp.array(9000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, @@ -41,7 +41,7 @@ def test_ssi_eligible_no_medicare() -> None: result = health_insurance.is_ssi_eligible( assets=jnp.array(1000.0), countable_income=jnp.array(1000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), gets_medicare=jnp.asarray(False), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, @@ -53,7 +53,7 @@ def test_ssi_eligible_all_pass() -> None: result = health_insurance.is_ssi_eligible( assets=jnp.array(1000.0), countable_income=jnp.array(1000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, @@ -64,7 +64,7 @@ def test_ssi_eligible_all_pass() -> None: def test_ssi_benefit_eligible() -> None: result = health_insurance.ssi_benefit( countable_income=jnp.array(3000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), is_ssi_eligible=jnp.array(True), ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -74,7 +74,7 @@ def test_ssi_benefit_eligible() -> None: def test_ssi_benefit_not_eligible() -> None: result = health_insurance.ssi_benefit( countable_income=jnp.array(3000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), is_ssi_eligible=jnp.array(False), ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -149,8 +149,8 @@ def test_compute_table_uniform_transition(table_inputs: dict) -> None: _PREMIUM_KWARGS: dict = { "age": jnp.int32(60), - "good_health": jnp.array(True), - "is_married": jnp.array(False), + "good_health": jnp.int32(1), + "is_married": jnp.int32(0), "labor_supply": jnp.array(LaborSupply.h2000), "premium_intercept": jnp.asarray(1000.0), "premium_age": jnp.asarray(0.0), diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index 3b16522..17326fd 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -8,8 +8,8 @@ """ import jax.numpy as jnp +from _lcm.simulation.initial_conditions import validate_initial_conditions from lcm import DiscreteGrid -from lcm.simulation.initial_conditions import validate_initial_conditions from aca_model.agent.assets_and_income import borrowing_constraint from aca_model.agent.preferences import BenchmarkPrefType @@ -112,17 +112,17 @@ def test_extreme_negative_assets_subject_passes_validation() -> None: initial_conditions = { **initial_conditions, "assets": jnp.asarray([-1_000_000.0]), - "regime": jnp.asarray( + "regime_id": jnp.asarray( [model.regime_names_to_ids["retiree_nomc_inelig_canwork"]], dtype=jnp.int32, ), } - internal_params = model._process_params(params) # noqa: SLF001 + flat_params = model._process_params(params) # noqa: SLF001 validate_initial_conditions( initial_conditions=initial_conditions, - internal_regimes=model.internal_regimes, + regimes=model._regimes, # noqa: SLF001 regime_names_to_ids=model.regime_names_to_ids, - internal_params=internal_params, + flat_params=flat_params, ages=model.ages, ) diff --git a/tests/test_model_components.py b/tests/test_model_components.py index b3569c5..260c203 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -7,20 +7,20 @@ def test_equivalence_scale_single() -> None: - result = preferences.equivalence_scale(jnp.array(False), jnp.asarray(0.7)) + result = preferences.equivalence_scale(jnp.int32(0), jnp.asarray(0.7)) assert jnp.isclose(result, 1.0) def test_equivalence_scale_married() -> None: - result = preferences.equivalence_scale(jnp.array(True), jnp.asarray(0.7)) + result = preferences.equivalence_scale(jnp.int32(1), jnp.asarray(0.7)) assert jnp.isclose(result, 2.0**0.7) def test_leisure_not_working() -> None: result = preferences.leisure_canwork_retiree_or_nongroup( working_hours_value=jnp.array(0.0), - good_health=jnp.array(1.0), - lagged_labor_supply=jnp.array(0), + good_health=jnp.int32(1), + lagged_labor_supply=jnp.int32(0), time_endowment=jnp.asarray(5000.0), leisure_cost_of_bad_health=jnp.asarray(500.0), fixed_cost_of_work=jnp.asarray(150.0), @@ -32,8 +32,8 @@ def test_leisure_not_working() -> None: def test_leisure_working_good_health() -> None: result = preferences.leisure_canwork_retiree_or_nongroup( working_hours_value=jnp.array(2000.0), - good_health=jnp.array(1.0), - lagged_labor_supply=jnp.array(1), + good_health=jnp.int32(1), + lagged_labor_supply=jnp.int32(1), time_endowment=jnp.asarray(5000.0), leisure_cost_of_bad_health=jnp.asarray(500.0), fixed_cost_of_work=jnp.asarray(150.0), @@ -47,8 +47,8 @@ def test_leisure_working_good_health() -> None: def test_leisure_reentry_cost() -> None: result = preferences.leisure_canwork_retiree_or_nongroup( working_hours_value=jnp.array(2000.0), - good_health=jnp.array(1.0), - lagged_labor_supply=jnp.array(0), + good_health=jnp.int32(1), + lagged_labor_supply=jnp.int32(0), time_endowment=jnp.asarray(5000.0), leisure_cost_of_bad_health=jnp.asarray(500.0), fixed_cost_of_work=jnp.asarray(150.0), @@ -60,7 +60,7 @@ def test_leisure_reentry_cost() -> None: def test_leisure_bad_health() -> None: result = preferences.leisure_forcedout( - good_health=jnp.array(0.0), + good_health=jnp.int32(0), time_endowment=jnp.asarray(5000.0), leisure_cost_of_bad_health=jnp.asarray(500.0), ) diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index 7ae6e36..d19758c 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -43,7 +43,7 @@ def build_regime(name: str): def test_model_creates_successfully() -> None: model = make_baseline_model(n_subjects=1) - assert len(model.regimes) == 19 + assert len(model.user_regimes) == 19 assert model.n_periods == 45 @@ -55,13 +55,13 @@ def test_model_age_range() -> None: def test_dead_regime_is_terminal() -> None: model = make_baseline_model(n_subjects=1) - assert model.regimes["dead"].terminal + assert model.user_regimes["dead"].terminal def test_non_terminal_regimes_not_terminal() -> None: model = make_baseline_model(n_subjects=1) for name in REGIME_SPECS: - assert not model.regimes[name].terminal + assert not model.user_regimes[name].terminal def test_regime_id_dead_is_last() -> None: @@ -192,7 +192,7 @@ def test_hcc_persistent_and_transitory_are_shock_grids() -> None: def test_aca_model_creates_successfully() -> None: model = make_aca_model(n_subjects=1, policy=PolicyVariant.ACA) - assert len(model.regimes) == 19 + assert len(model.user_regimes) == 19 assert model.n_periods == 45 @@ -233,7 +233,7 @@ def test_aca_other_regimes_have_no_aca_policy_keys() -> None: def test_all_policy_variants_create(policy: PolicyVariant) -> None: """All policy variants create valid models.""" model = make_aca_model(n_subjects=1, policy=policy) - assert len(model.regimes) == 19 + assert len(model.user_regimes) == 19 def test_aca_no_medicaid_expansion_keeps_baseline_medicaid() -> None: @@ -273,4 +273,4 @@ def test_aca_only_medicaid_expansion() -> None: def test_baseline_model_creates() -> None: """Baseline model creates successfully without PolicyVariant.""" model = make_baseline_model(n_subjects=1) - assert len(model.regimes) == 19 + assert len(model.user_regimes) == 19 diff --git a/tests/test_regime_transitions.py b/tests/test_regime_transitions.py index 87cdeea..9040fdc 100644 --- a/tests/test_regime_transitions.py +++ b/tests/test_regime_transitions.py @@ -45,8 +45,8 @@ def test_tied_stop_working_becomes_nongroup() -> None: transition = tied_canwork(gets_medicare=False, own=own, ng=ng) probs = transition( - age=55, - period=4, + age=jnp.int32(55), + period=jnp.int32(4), labor_supply=jnp.array(LaborSupply.do_not_work), is_medicaid_eligible=jnp.array(False), survival_probs=SURVIVAL, @@ -62,8 +62,8 @@ def test_tied_keeps_working_stays_tied() -> None: transition = tied_canwork(gets_medicare=False, own=own, ng=ng) probs = transition( - age=55, - period=4, + age=jnp.int32(55), + period=jnp.int32(4), labor_supply=jnp.array(LaborSupply.h2000), is_medicaid_eligible=jnp.array(False), survival_probs=SURVIVAL, @@ -81,8 +81,8 @@ def test_retiree_medicaid_override_to_nongroup() -> None: transition = retiree_canwork(gets_medicare=False, own=own, ng=ng) probs = transition( - age=55, - period=4, + age=jnp.int32(55), + period=jnp.int32(4), labor_supply=jnp.array(LaborSupply.h2000), is_medicaid_eligible=jnp.array(True), survival_probs=SURVIVAL, @@ -97,8 +97,8 @@ def test_retiree_not_medicaid_stays_retiree() -> None: transition = retiree_canwork(gets_medicare=False, own=own, ng=ng) probs = transition( - age=55, - period=4, + age=jnp.int32(55), + period=jnp.int32(4), labor_supply=jnp.array(LaborSupply.h2000), is_medicaid_eligible=jnp.array(False), survival_probs=SURVIVAL, @@ -113,8 +113,8 @@ def test_retiree_forcedout_medicaid_override() -> None: transition = retiree_forcedout(gets_medicare=True, own=own, ng=ng) probs = transition( - age=80, - period=29, + age=jnp.int32(80), + period=jnp.int32(29), is_medicaid_eligible=jnp.array(True), survival_probs=SURVIVAL, ) @@ -150,9 +150,9 @@ def test_retiree_age_bracket_transitions( own, ng = make_targets("retiree_nomc_inelig_canwork") transition = retiree_canwork(gets_medicare=False, own=own, ng=ng) - period = int(age - MODEL_CONFIG.start_age) + period = jnp.int32(age - MODEL_CONFIG.start_age) probs = transition( - age=age, + age=jnp.asarray(age), period=period, labor_supply=jnp.array(LaborSupply.h2000), is_medicaid_eligible=jnp.array(False), @@ -171,8 +171,8 @@ def test_nongroup_canwork_valid_probs() -> None: transition = nongroup_canwork(gets_medicare=False, own=own) probs = transition( - age=55, - period=4, + age=jnp.int32(55), + period=jnp.int32(4), labor_supply=jnp.array(LaborSupply.h2000), survival_probs=SURVIVAL, ) @@ -186,8 +186,8 @@ def test_nongroup_forcedout_valid_probs() -> None: transition = nongroup_forcedout(gets_medicare=True, own=own) probs = transition( - age=80, - period=29, + age=jnp.int32(80), + period=jnp.int32(29), survival_probs=SURVIVAL, ) assert jnp.isclose(jnp.sum(probs), 1.0, atol=1e-6) @@ -203,8 +203,8 @@ def test_tied_medicaid_override_to_nongroup() -> None: transition = tied_canwork(gets_medicare=False, own=own, ng=ng) probs = transition( - age=55, - period=4, + age=jnp.int32(55), + period=jnp.int32(4), labor_supply=jnp.array(LaborSupply.h2000), is_medicaid_eligible=jnp.array(True), survival_probs=SURVIVAL, @@ -218,9 +218,9 @@ def test_tied_at_medicare_age_with_medicaid() -> None: own, ng = make_targets("tied_nomc_choose_canwork") transition = tied_canwork(gets_medicare=False, own=own, ng=ng) - period = int(64 - MODEL_CONFIG.start_age) + period = jnp.int32(64 - MODEL_CONFIG.start_age) probs = transition( - age=64, + age=jnp.int32(64), period=period, labor_supply=jnp.array(LaborSupply.h2000), is_medicaid_eligible=jnp.array(True), @@ -238,8 +238,8 @@ def test_survival_prob_determines_death_weight() -> None: survival = jnp.ones(N_PERIODS) * 0.85 probs = transition( - age=55, - period=4, + age=jnp.int32(55), + period=jnp.int32(4), labor_supply=jnp.array(LaborSupply.h2000), is_medicaid_eligible=jnp.array(False), survival_probs=survival, diff --git a/tests/test_social_security.py b/tests/test_social_security.py index c4c704f..e06912f 100644 --- a/tests/test_social_security.py +++ b/tests/test_social_security.py @@ -350,7 +350,7 @@ def test_benefit_inelig_pre65_disabled_below_sga() -> None: ) result = social_security.benefit_inelig_pre65( ssdi_pia=ssdi_val, - health=jnp.array(0), # disabled + health=jnp.int32(0), # disabled labor_income=jnp.array(0.0), ssdi_substantial_gainful_activity=SSDI_SGA, ) @@ -368,7 +368,7 @@ def test_benefit_inelig_pre65_disabled_above_sga() -> None: ) result = social_security.benefit_inelig_pre65( ssdi_pia=ssdi_val, - health=jnp.array(0), # disabled + health=jnp.int32(0), # disabled labor_income=jnp.array(20000.0), ssdi_substantial_gainful_activity=SSDI_SGA, ) @@ -379,7 +379,7 @@ def test_benefit_inelig_pre65_not_disabled() -> None: """Non-disabled agent: benefit = 0.""" result = social_security.benefit_inelig_pre65( ssdi_pia=jnp.array(1000.0), - health=jnp.array(2), # good health + health=jnp.int32(2), # good health labor_income=jnp.array(0.0), ssdi_substantial_gainful_activity=SSDI_SGA, ) diff --git a/tests/test_ss_benefit_integration.py b/tests/test_ss_benefit_integration.py index 81a1b61..bd0e9b2 100644 --- a/tests/test_ss_benefit_integration.py +++ b/tests/test_ss_benefit_integration.py @@ -56,7 +56,7 @@ def test_earnings_test_reduces_benefit_before_fra() -> None: period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), - health=jnp.array(2), + health=jnp.int32(2), labor_supply=jnp.array(LaborSupply.h2000), labor_income=jnp.array(30000.0), early_ret_adjustment=jnp.array([0.75]), @@ -74,7 +74,7 @@ def test_earnings_test_reduces_benefit_before_fra() -> None: period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), - health=jnp.array(2), + health=jnp.int32(2), labor_supply=jnp.array(LaborSupply.do_not_work), labor_income=jnp.array(0.0), early_ret_adjustment=jnp.array([0.75]), diff --git a/tests/test_ssi_medicaid_integration.py b/tests/test_ssi_medicaid_integration.py index 345d751..3bc6d92 100644 --- a/tests/test_ssi_medicaid_integration.py +++ b/tests/test_ssi_medicaid_integration.py @@ -35,11 +35,11 @@ def test_low_income_qualifies_for_ssi_and_medicaid() -> None: spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), ss_benefit=jnp.array(500.0), pension_benefit=jnp.array(0.0), - ssi_ignored_overall=20.0, - ssi_ignored_earned=65.0, + ssi_ignored_overall=jnp.asarray(20.0), + ssi_ignored_earned=jnp.asarray(65.0), assets=jnp.array(1000.0), - spousal_income=jnp.array(0), - gets_medicare=True, + spousal_income=jnp.int32(0), + gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -66,11 +66,11 @@ def test_high_income_ineligible_for_ssi() -> None: spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), ss_benefit=jnp.array(2000.0), pension_benefit=jnp.array(0.0), - ssi_ignored_overall=20.0, - ssi_ignored_earned=65.0, + ssi_ignored_overall=jnp.asarray(20.0), + ssi_ignored_earned=jnp.asarray(65.0), assets=jnp.array(1000.0), - spousal_income=jnp.array(0), - gets_medicare=True, + spousal_income=jnp.int32(0), + gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -95,11 +95,11 @@ def test_no_medicare_blocks_ssi_under_baseline() -> None: spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), ss_benefit=jnp.array(0.0), pension_benefit=jnp.array(0.0), - ssi_ignored_overall=20.0, - ssi_ignored_earned=65.0, + ssi_ignored_overall=jnp.asarray(20.0), + ssi_ignored_earned=jnp.asarray(65.0), assets=jnp.array(100.0), - spousal_income=jnp.array(0), - gets_medicare=False, + spousal_income=jnp.int32(0), + gets_medicare=jnp.asarray(False), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -119,26 +119,26 @@ def test_medicaid_reduces_oop() -> None: oop_medicaid = combined( total_health_costs=jnp.array(10000.0), buy_private=jnp.array(BuyPrivate.yes), - deductible=500.0, - coinsurance_rate=0.2, - oop_max=5000.0, + deductible=jnp.asarray(500.0), + coinsurance_rate=jnp.asarray(0.2), + oop_max=jnp.asarray(5000.0), is_ssi_eligible=jnp.array(True), - deductible_medicaid=100.0, - coinsurance_rate_medicaid=0.05, - oop_max_medicaid=1000.0, + deductible_medicaid=jnp.asarray(100.0), + coinsurance_rate_medicaid=jnp.asarray(0.05), + oop_max_medicaid=jnp.asarray(1000.0), ) # Not Medicaid-eligible: primary OOP only oop_no_medicaid = combined( total_health_costs=jnp.array(10000.0), buy_private=jnp.array(BuyPrivate.yes), - deductible=500.0, - coinsurance_rate=0.2, - oop_max=5000.0, + deductible=jnp.asarray(500.0), + coinsurance_rate=jnp.asarray(0.2), + oop_max=jnp.asarray(5000.0), is_ssi_eligible=jnp.array(False), - deductible_medicaid=100.0, - coinsurance_rate_medicaid=0.05, - oop_max_medicaid=1000.0, + deductible_medicaid=jnp.asarray(100.0), + coinsurance_rate_medicaid=jnp.asarray(0.05), + oop_max_medicaid=jnp.asarray(1000.0), ) assert oop_medicaid < oop_no_medicaid diff --git a/tests/test_taxes.py b/tests/test_taxes.py index 958eb0b..862c024 100644 --- a/tests/test_taxes.py +++ b/tests/test_taxes.py @@ -224,7 +224,7 @@ def test_taxable_ss_benefit_below_threshold() -> None: spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), ss_benefit=jnp.array(5000.0), pension_benefit=jnp.array(0.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), ss_tax_schedule=SS_TAX_SCHEDULE, ) # Provisional income = 10000 + 0.5*5000 = 12500, below 25000 threshold @@ -235,7 +235,7 @@ def test_gross_income_basic() -> None: result = taxes.gross_income( capital_income=jnp.array(1000.0), labor_income=jnp.array(5000.0), - spousal_income=jnp.array(1), + spousal_income=jnp.int32(1), spousal_income_amounts=jnp.array([0.0, 2000.0, 20000.0]), taxable_ss_benefit=jnp.array(500.0), pension_benefit=jnp.array(300.0), @@ -247,7 +247,7 @@ def test_after_tax_income_zero() -> None: gi = taxes.gross_income( capital_income=jnp.array(0.0), labor_income=jnp.array(0.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), taxable_ss_benefit=jnp.array(0.0), pension_benefit=jnp.array(0.0), @@ -257,7 +257,7 @@ def test_after_tax_income_zero() -> None: ss_benefit=jnp.array(0.0), taxable_ss_benefit=jnp.array(0.0), labor_income=jnp.array(0.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), income_tax_schedule=INCOME_TAX_SCHEDULE, payroll_tax_schedule=PAYROLL_TAX_SCHEDULE, ) @@ -269,7 +269,7 @@ def test_after_tax_income_low_bracket() -> None: gi = taxes.gross_income( capital_income=jnp.array(0.0), labor_income=jnp.array(gross), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), taxable_ss_benefit=jnp.array(0.0), pension_benefit=jnp.array(0.0), @@ -279,7 +279,7 @@ def test_after_tax_income_low_bracket() -> None: ss_benefit=jnp.array(0.0), taxable_ss_benefit=jnp.array(0.0), labor_income=jnp.array(gross), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), income_tax_schedule=INCOME_TAX_SCHEDULE, payroll_tax_schedule=PAYROLL_TAX_SCHEDULE, ) @@ -293,7 +293,7 @@ def test_after_tax_income_low_bracket() -> None: def test_marginal_tax_rate_low_bracket() -> None: result = taxes.marginal_rate( gross_income=jnp.array(5000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), income_tax_schedule=INCOME_TAX_SCHEDULE, ) # 5000 is in bracket 1 (0-6200), rate = 0.0765 @@ -303,7 +303,7 @@ def test_marginal_tax_rate_low_bracket() -> None: def test_marginal_tax_rate_mid_bracket() -> None: result = taxes.marginal_rate( gross_income=jnp.array(10000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), income_tax_schedule=INCOME_TAX_SCHEDULE, ) # 10000 is in bracket 2 (6200-15275), rate = 0.199 From 97c11ee882b63c60038eebc6012fb7c679476923 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 22 May 2026 20:44:19 +0200 Subject: [PATCH 02/21] Declare the dead regime's assets grid non-distributed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The terminal `dead` regime carries only a tiny `[pref_type, assets]` value function. Inheriting the distributed `assets` grid made its V-array topology claim a sharded assets axis, while the solver emits a replicated array — a mismatch that surfaces as an opaque XLA sharding error mid-solve on multi-GPU runs. Sharding a terminal 3x24 V-array buys nothing; declare `assets` non-distributed for `dead`. Co-Authored-By: Claude Opus 4.7 --- src/aca_model/baseline/regimes/_common.py | 2 +- tests/test_model_creation.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index b783ecb..f4b68bd 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -449,7 +449,7 @@ def build_dead_regime(grids: Grids) -> Regime: "utility_scale_factor": preferences.utility_scale_factor, }, states={ - "assets": grids.assets, + "assets": grids.assets.replace(distributed=False), "pref_type": grids.pref_type, }, active=lambda _age: True, diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index d19758c..f948702 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -58,6 +58,18 @@ def test_dead_regime_is_terminal() -> None: assert model.user_regimes["dead"].terminal +def test_dead_regime_assets_grid_not_distributed() -> None: + """The terminal `dead` regime declares its `assets` grid non-distributed. + + `dead` carries only a tiny `[pref_type, assets]` value function; + sharding its assets axis across devices buys nothing and makes its + V-array sharding disagree with what the multi-GPU solve produces for + the regimes that transition into it. + """ + dead = build_regime("dead") + assert dead.states["assets"].distributed is False + + def test_non_terminal_regimes_not_terminal() -> None: model = make_baseline_model(n_subjects=1) for name in REGIME_SPECS: From ec29319d15186730953bf337fbfdb8d91dfe380e Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 23 May 2026 16:43:04 +0200 Subject: [PATCH 03/21] Thread subjects_batch_size through create_model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `GridConfig.subjects_batch_size_by_log_level` is a mapping from `log_level` string (`"off"`, `"warning"`, `"progress"`, `"debug"`) to the per-device simulate chunk size. Empty by default — the lookup helper `GridConfig.get_subjects_batch_size(log_level)` returns 0, matching the existing no-chunking behaviour. `create_model` (both baseline and aca variants) gains an optional `subjects_batch_size` keyword, forwarded to the pylcm `Model`. Callers in aca-estimation look up the value via `grid_config.get_subjects_batch_size(log_level)` keyed on the same `log_level` they pass to `model.simulate(...)`, so each task automatically gets the chunk size sized for its diagnostic budget. Co-Authored-By: Claude Opus 4.7 --- src/aca_model/aca/model.py | 6 ++++++ src/aca_model/baseline/model.py | 6 ++++++ src/aca_model/config.py | 19 ++++++++++++++++++- 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 6d39ac6..22437dd 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -25,6 +25,7 @@ def create_model( derived_categoricals: Mapping[str, DiscreteGrid], grid_config: GridConfig, pref_type_grid: DiscreteGrid, + subjects_batch_size: int = 0, ) -> Model: """Create an ACA policy variant model. @@ -45,6 +46,10 @@ def create_model( `pref_type`. grid_config: Continuous-grid point counts. pref_type_grid: Pref-type `DiscreteGrid`. + subjects_batch_size: Per-device chunk size for the simulate-side + per-subject dispatch. `0` (default) keeps a single vmap over + all subjects; `>0` chunks each device's local shard via + `jax.lax.map`. Tune via `grid_config.get_subjects_batch_size(log_level)`. Returns: pylcm Model. @@ -71,4 +76,5 @@ def create_model( fixed_params=fixed_params, derived_categoricals=derived_categoricals, n_subjects=n_subjects, + subjects_batch_size=subjects_batch_size, ) diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index 98416ce..85fd8d9 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -28,6 +28,7 @@ def create_model( derived_categoricals: Mapping[str, DiscreteGrid], grid_config: GridConfig, pref_type_grid: DiscreteGrid, + subjects_batch_size: int = 0, ) -> Model: """Create the baseline structural retirement model. @@ -52,6 +53,10 @@ def create_model( pref_type_grid: Pref-type `DiscreteGrid`. Pass `DiscreteGrid(PrefType)` for the production 3-type layout, or a compact variant (e.g. `DiscreteGrid(BenchmarkPrefType)`). + subjects_batch_size: Per-device chunk size for the simulate-side + per-subject dispatch. `0` (default) keeps a single vmap over + all subjects; `>0` chunks each device's local shard via + `jax.lax.map`. Tune via `grid_config.get_subjects_batch_size(log_level)`. Returns: A pylcm Model with 19 regimes (18 non-terminal + dead) spanning @@ -78,4 +83,5 @@ def create_model( fixed_params=fixed_params, derived_categoricals=derived_categoricals, n_subjects=n_subjects, + subjects_batch_size=subjects_batch_size, ) diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 101ef2d..955a5cb 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -1,7 +1,8 @@ """Configuration for the aca_model package.""" -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path +from types import MappingProxyType import plotly.io as pio from pytask import DataCatalog @@ -39,6 +40,22 @@ class GridConfig: # BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. n_assets_batch_size: int = 1 n_aime_batch_size: int = 1 + # Per-device chunk size for the simulate-side per-subject dispatch, + # keyed by `log_level`. Empty → 0 (no chunking) for every level. + # `log_level="off"` skips `validate_V` and its forced host-sync, which + # lets XLA pipeline across periods and reuse scratch — affordable + # chunk size grows. Use `get_subjects_batch_size(log_level)`. + subjects_batch_size_by_log_level: MappingProxyType[str, int] = field( + default_factory=lambda: MappingProxyType({}) + ) + + def get_subjects_batch_size(self, log_level: str) -> int: + """Return the per-device simulate chunk size for `log_level`. + + Returns 0 (no chunking) when this `GridConfig` defines no entry for + the given log level. + """ + return self.subjects_batch_size_by_log_level.get(log_level, 0) MODEL_CONFIG = ModelConfig() From 5e099a0585c2616a3585d0347f7d5ffc00154810 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 23 May 2026 16:49:58 +0200 Subject: [PATCH 04/21] CI: bump pylcm pin to feat/distributed-V-arrays (PR #364) aca-model now passes `subjects_batch_size` to `Model(...)` (see ec29319), which is a new field introduced on the `feat/distributed-V-arrays` branch (PR #364) on top of `refactor/phase-2-api-reorganisation` (PR #361). The CI pin still pointed at #361, so the `pip install` pulled a pylcm without `subjects_batch_size`, and every Model-construction test raised `TypeError: Model.__init__() got an unexpected keyword argument`. Re-point the CI pin to the PR #364 branch. Will move to `main` once #361 and #364 land. Co-Authored-By: Claude Opus 4.7 --- .github/workflows/main.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index e57ac0b..5c78f7c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -30,10 +30,10 @@ jobs: - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - - name: Install pylcm (pinned to the phase-2 branch until it merges to main) + - name: Install pylcm (pinned to feat/distributed-V-arrays / PR #364 until it merges to main) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@refactor/phase-2-api-reorganisation" + git+https://github.com/OpenSourceEconomics/pylcm.git@feat/distributed-V-arrays" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest From a3bc60a3736424b2580213f728cd76b02bb317a2 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 23 May 2026 18:29:56 +0200 Subject: [PATCH 05/21] =?UTF-8?q?Drop=20estimagic=20dep=20=E2=80=94=20opti?= =?UTF-8?q?magic=20is=20the=20new=20name?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7ca2558..9df7876 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,6 @@ dependencies = [ "beartype", "cloudpickle", "dags", - "estimagic", "jax>=0.9", "jaxtyping", "numpy>=2.2", From d982b90dc77568d36a472dcd389b11c5f7b17dad Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 23 May 2026 19:14:44 +0200 Subject: [PATCH 06/21] Add GridConfig.n_pref_type_batch_size for pref-type splay MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Solve's per-period `max_Q_over_a` integrand spans the full state grid by default; on A100 even with the assets axis distributed across 4 GPUs the working set runs against the 80 GB device limit. Splaying the pref-type axis with a Python loop (`batch_size=1`) shrinks the per-kernel allocation by `n_pref_types`. Default stays `0` (single fused kernel) — the production GridConfig overrides opt into `1` when the unsplayed kernel doesn't fit. --- src/aca_model/config.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 955a5cb..9e7c77f 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -40,6 +40,13 @@ class GridConfig: # BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. n_assets_batch_size: int = 1 n_aime_batch_size: int = 1 + # `batch_size` on the `pref_type` discrete grid: chunked vmap stride + # for the pref-type axis during solve. `1` (one pref-type per Python + # dispatch) shrinks the per-period Q intermediate by `n_pref_types` + # at the cost of an outer Python loop; `0` lets a single kernel span + # all pref-types. Defaults to `0` — the production overrides set it + # to `1` on hardware where the unsplayed kernel doesn't fit. + n_pref_type_batch_size: int = 0 # Per-device chunk size for the simulate-side per-subject dispatch, # keyed by `log_level`. Empty → 0 (no chunking) for every level. # `log_level="off"` skips `validate_V` and its forced host-sync, which From 7a6698f31062cdc81911fdffde7bf40a830f30da Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 23 May 2026 21:56:24 +0200 Subject: [PATCH 07/21] =?UTF-8?q?GridConfig:=20assets=5Fbatch=5Fsize=20def?= =?UTF-8?q?ault=201=20=E2=86=92=200=20(distributed=20axis)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The assets axis is hardcoded distributed=True in regimes; pylcm's grid-init guard rejects batch_size > 0 + distributed. The default has to match that constraint or every fresh GridConfig() raises GridInitializationError at construction. --- src/aca_model/config.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 9e7c77f..e240841 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -35,10 +35,12 @@ class GridConfig: n_hcc_persistent_gridpoints: int = 3 n_hcc_transitory_gridpoints: int = 5 # `batch_size` on the assets / AIME grids: chunked vmap stride for the - # outer state loop. Both partition the per-period Q intermediate so it - # fits in V100 16 GB once we splay across `pref_type`. Set to 0 in - # BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. - n_assets_batch_size: int = 1 + # outer state loop. The assets axis is hardcoded `distributed=True` + # in regimes, so `n_assets_batch_size` must stay `0` — `>0 + distributed` + # is rejected by pylcm's grid-init guard. `n_aime_batch_size` is free + # to splay; `1` shrinks the per-period Q intermediate by 12x on hosts + # where the unsplayed kernel doesn't fit. + n_assets_batch_size: int = 0 n_aime_batch_size: int = 1 # `batch_size` on the `pref_type` discrete grid: chunked vmap stride # for the pref-type axis during solve. `1` (one pref-type per Python From 63f0b9b395aca196017b7b0394df2280f9dbaa31 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 24 May 2026 10:58:10 +0200 Subject: [PATCH 08/21] =?UTF-8?q?Comments:=20aca-estimation=20=E2=86=92=20?= =?UTF-8?q?aca-slurm=20(downstream=20package=20rename)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/aca_model/benchmark.py | 2 +- tests/helpers/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index a9ea000..8327cd8 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -12,7 +12,7 @@ Parameters (`fixed_params` + `params`) are a committed snapshot at `src/aca_model/_benchmark_data/benchmark_params.pkl`, generated by `scripts/regen_benchmark_params.py` against the current aca-data + -aca-estimation + aca-model code. Pref-type-indexed Series in `params` +aca-slurm + aca-model code. Pref-type-indexed Series in `params` are pre-truncated to two rows so the snapshot loads with no further reshaping; regenerate after any change that affects `fixed_params` shape (regime DAGs, aca-data outputs, key renames). diff --git a/tests/helpers/model.py b/tests/helpers/model.py index be778b4..55571e9 100644 --- a/tests/helpers/model.py +++ b/tests/helpers/model.py @@ -2,7 +2,7 @@ Used by tests that need a structurally faithful model without spelling out fixed_params, wage_params, and a pref-type grid at every call site. -Production callers (aca-estimation, scripts) assemble these explicitly. +Production callers (aca-slurm, scripts) assemble these explicitly. """ from lcm import DiscreteGrid, Model From 4138d41877976ad231a029b67449b01cecec8fc8 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 24 May 2026 13:26:54 +0200 Subject: [PATCH 09/21] Drop dead-regime distributed=False workaround on assets grid. The pylcm sharding-consistency validator now requires every state to carry the same `distributed` flag in every regime that declares it. Restoring `assets` to the shared grid lets the dead regime's V-array sharding match the alive regimes that transition into it. Also drops the test that codified the workaround. Co-Authored-By: Claude Opus 4.7 --- src/aca_model/baseline/regimes/_common.py | 2 +- tests/test_model_creation.py | 12 ------------ 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index f4b68bd..b783ecb 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -449,7 +449,7 @@ def build_dead_regime(grids: Grids) -> Regime: "utility_scale_factor": preferences.utility_scale_factor, }, states={ - "assets": grids.assets.replace(distributed=False), + "assets": grids.assets, "pref_type": grids.pref_type, }, active=lambda _age: True, diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index f948702..d19758c 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -58,18 +58,6 @@ def test_dead_regime_is_terminal() -> None: assert model.user_regimes["dead"].terminal -def test_dead_regime_assets_grid_not_distributed() -> None: - """The terminal `dead` regime declares its `assets` grid non-distributed. - - `dead` carries only a tiny `[pref_type, assets]` value function; - sharding its assets axis across devices buys nothing and makes its - V-array sharding disagree with what the multi-GPU solve produces for - the regimes that transition into it. - """ - dead = build_regime("dead") - assert dead.states["assets"].distributed is False - - def test_non_terminal_regimes_not_terminal() -> None: model = make_baseline_model(n_subjects=1) for name in REGIME_SPECS: From eea782a80a4db400a0d19fd0803da78bd90b7f70 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 24 May 2026 14:26:33 +0200 Subject: [PATCH 10/21] Use the actual scalar param key in the borrowing-constraint benchmark test. `consumption_dollars_floor` is a DAG function output; the scalar key in `params` is `consumption_equiv_floor`. Looking up the function name raised `KeyError` before the constraint check ran. Co-Authored-By: Claude Opus 4.7 --- tests/test_benchmark.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 61b6f8b..5d7ddbd 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -72,8 +72,8 @@ def test_benchmark_simulate_obeys_borrowing_constraint() -> None: df = result.to_dataframe(additional_targets=["cash_on_hand", "equivalence_scale"]) alive = df.loc[df["regime_name"] != "dead"].copy() - consumption_dollars_floor = float(params["consumption_dollars_floor"]) - floor = consumption_dollars_floor * alive["equivalence_scale"].to_numpy() + consumption_equiv_floor = float(params["consumption_equiv_floor"]) + floor = consumption_equiv_floor * alive["equivalence_scale"].to_numpy() rhs = np.maximum(alive["cash_on_hand"].to_numpy(), floor) slack = rhs - alive["consumption_dollars"].to_numpy() assert (slack >= 0).all(), ( From 2238869019e247fa257344ecd7ef0c12ba202e27 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 24 May 2026 15:15:04 +0200 Subject: [PATCH 11/21] Regenerate benchmark_params.pkl after aca-data cross + pre-65 health fixes. Picks up the validated cross-grid (`health_trans_cross`) and same-grid pre-65 (`health_trans_pre65`) transition matrices that now carry valid probability rows at every source-active age (51-64). Co-Authored-By: Claude Opus 4.7 --- .../_benchmark_data/benchmark_params.pkl | Bin 68025 -> 68534 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/src/aca_model/_benchmark_data/benchmark_params.pkl b/src/aca_model/_benchmark_data/benchmark_params.pkl index 15053c68abd1d7813b2188faf491cb90243bcdce..4b492bbb72c2a5a5aeaed85a72001f938d2184b3 100644 GIT binary patch delta 15303 zcmdU02V7Iv_sA9f8~s=n8IkaiAwyb4A}T5h3f>@9YefaM zh@i&y=PZZ|t*csb*1d7z!dlV)eR&BatRRk0|M`4;H+OySIp2HEx#!*BE~Nciwzl?c zRu-{MyNZqMIFpckGC;D0|Xj1^a+WBML=_FH}R*)~R`A;DnOAdq}C!raER zA%HuE-H8h_Tch2C)-|CQ3e-2l;yyEOW$3I1&=J{~Gz>7<#0^UjYf?`FPy-mys05Sv zT`_=^Uk_lc8o-1`0VL*KbS?!>l!|>_mx~FMX)sAXkdH|}P5UbK>zB`@Tq%KWqR z_%}D@C1Y$AU1zcAd3y}xWm$qXRf-sc_SsTjS_&YauWrl+K{0{B;i9YIes8nzWrQxRl$mc6GMS<3q>KiDROzxl)iUsX@F8f0==!jri!Q#te z##w>qkPSKm+ny6VOF24v|E3FqnR7jG%OwGr1|B%u{%?VNGJdpR{*X@vkk41D_0GCg z18T{|Q`lvfD}Ta?sf%!(V_G{;*02Z#OLeVS%ytv_3lL7Ys9&+nIbW^-%0u zmn%0E^IyWC8#DONcQCw1?zBr;(3Rs|=zC+v

k+iV-`9*hwexKLtQOU#ZqA8yA~- z1cTw0EcL@t&6-TZ6&xkWCAh@qiG#i;74)%~wSJ(Ge0JSnO=j1dq!@}LJRYn7>8Iut zWKPYb-inpvlRW-*Vkc+G1XkMq1|7Di*EwDc;sU=H#Rde%1x$@fpgjqTs!ajl2*lTg znfsN804iza*rblX+ghO5!ZyCCKAxKbX@O7Aqf5pty1c%*;kvpe3iUA_bx?ZBr>wjV zxUBWC#5;}Fm3xw7Yl}@y4TW)CK$()4SB1hO8m}u*WN&OG1XGWRI^|gkO_>Zv>O{P& z=J*R~K|L(-Llh?s!ka!ZHFkkCNDQP6Zi}CfPwzn7ihEqeAtW2O)gENdO!W6gD|~c1 zGReNksb3vVYvZn%Y92NXKmG+YK$TKP=(6*OZJs@bI4#aBbX?8 ziJein+SKo{=6r>u^JQL6Y$g$bJ_sDpvuanAVdsvDoI{vvnP@+1%g@I%<@9X9U&uo0 zQ6%$oS))Z9l|5;ZipWIz_BdCt;)MLNwlzXY3Bo2Qcnnk2Y#lxtahdKkz$)&iF za-d9HV#>wlVoQoiYjPECz$8;Ba*}I%a_Alkx=zpk=$`$`9uavEf2Cl3UHi#WTvifH z@+f-J&O$cUZ8?@xE^fP{c#Mi%BC;Ge_>t%7sd5PvdZ11Me*bjbUc1HDs9KnKs*r*Z z5(}_m@qzI;sS6TElk=#~7&3(_RCi)NN)>eD)u@ij1Z_>S@gQLhIh5v@e@6x9L_Qh} z!|95HDKkY%)@*{U7G~HUodO5qAgTn@R+KtIwvuJ6xeuGbF@wpJV%=MNTf3DQ6)*URliR5? z5ll7pz4d_#ro=I$pkA3f~XUS&h=d=BE>2%L)0^!|I>N-?BxfpSOc5M$#@Zn{o*tr=v_{%O_`25(1v0pb`^?my!}v6gX?upk^)B01Gh` zY7GmUw2)DvA|4z;mxb+{mZ~F>h)R%ihZGDV4-E_P7c!N4EW!^ea?Y2>%!Y;j2!=ky z_tfr~i!o~nZOkUsSu4kLK0j+GDbyO)o>Qz@TOFbrW5~l)q3ZTW(CG!$oxiIBblB5D@5u9w2;V*Y>;gYO{@>KPYtS6BUFA^XJP~L%WNl$ z&Kyo?Ny)mh{iOzYQ6u2_@a|39jcqGMj7`&t3ycXB#RZ7NVx{t9LoL=iH6zwON3D}A zWhWw9XdjAL{~W7We|_KPtdD2bJMk*i7-OO2G@B-ZmT{aWzWAgfx;89l^LhHz=-%-9 zE#}FUr5cdDHSpw<*jIyH@kvGCY*@}_@g!@VG_qwUV3TXNn4=Iij%I1#C~=-E^tmS2Z!u3fYCP@Iz|*g3IOxglr!^2iuc!gtU}*})H8K)KWIA%>w^Xe;s$s>AqbaP6&F@1v$+Gg55lV^EOAQA! z;w&w2u6>=C{a&z&$qeD(+b{a^B?6<#Uvs1J7jha}T7)lOWZQ}^62u~ue(1W}BHS)K z-q;)eBx+V1^_VM@Zr{S5T>p(*;w+~T<4)1M4tdKq28^@OXj0eAaet=b=i@(8>5Ia6r6hIlt>pK+^hKd;&GAw# z`lhgExQ16$>;$Q!I*rtok6g#vORSIdAoZx4 zclP;t6bz}4YpJ9Gev0c`C}~Rd`e@^uVx?`G(0`jiTH_wlhI%BSlaDP?^k0?=t=8=^ z>Hcu{oqAf?qR6wR=CES_neltVgbisX-AVl^!W4yOe>k4LGub;;q1isqAzwl>Y${D@ zR(ak6$<8S?Ysr3*Utqr$%J+-3qS&vy4_C{T(l)jaj$FZ>i=A9EPiLX%i~J@wgoMw( zAqvqm>WW@@-s3;%6m>|BPqT~*6NLmsMnwb%1dF4hVnqSLfsx4bBJUH#7kvR6p?#WM zP)juOxa9JQE-Ly07NHPS#UfDO%g+BXHtTyF8h=F-o6UB#FF|X4M?(pJD``zjy6>Kn zIjvrZ#s^Pkf4<+Cs@(eFIT)r-#=#PY9QxonBC1|=Y1b=7cwT>2{?qzL9z1nvMU_Ic z+Qdgzbk&TYGaq|pzM-D^I5;-&t7l`0kAu6W!Kyg~dU}`2cz(E-fuB0}Gw8f0Q7K0g ziGwWp=37PHQ!`vc-m@8dDWw@cjq2p>Tj`z#@}A%69#xrZ^PWWX_HL4@_;lV=BTp0J zHQ8B^zz(NrfdJ6Gd;{2j7$9`=Vvlc}?awC$^ibw9BAnI2$6jFz$ zlADuB*mkH)9i4mZ_W!|98WZ2(X**5qbCstAdUl2X1^HGn4NPvD-qgULsF;AT$l$05 zkyQCsWX!0m3F`B#v!)&3g2Gj46~2;llE!K#qck86eXb@UveLKq_(OI+q3B5=pup#R z;$U_FmzaXCJYVY63u6e!H+B%&;a;)}TKd9LG9$cIWTiP1MzvkDwXcK&v-kHlO&y>G zkuz>B4k9_=T(oiA^#0l~+1noqLpb0yF3)kE1q*&0H>&6Ga1M;!@;)`9lm$UQyH7zg zIWS@QD36gUrI50gvv?QgF4KokkFWp-f8z6Y-oS-%q*J@ zT-a}(_&R&Y8(y3kP`3FA2VA!E&br00!Qjz3ro%1{9G__! zobS(ul1r=Fw~gh%#H42#cl)#9Yvkga;R`ON&?%A0DKpK3KQamIv%jdgY^s z@Sw2m4+ zBn7*8Fh(c8+^;(i%>F4i=9Tkc>BIMli5+;5+uQW~4^=!E-1GI~AMJT?!mIs|-|Tcg zXMObN2Bh$iyS)oq{o1hcpNZ^H=D&8PZLk~QjH(MR#sez&FjEl!Qso$FNfu>DgG zHUBA#3*H@_5`J;z!k&-~`y??ATv!rrNX+Si0e0jh2KUAT7Z(4v z3Ak{RJI6oSp9@S|@nhGaT$s>n-i53vE|l3ji*&|wA=O#n8oPiC_WFK~!$Y~SYtiP{ zuU2y*%Vkf-n>a44V3qdNmT_UtyM3bfbGR_N*N24k!(4FeIcUh(OfHOh&_ObB(FHDi z_pU1AH2cr@_qkA&^6cn=LM~jHG2p-}>@BZ#94;5^<3d{E*xU25w_Lhe#?Cs< zg%1{ko}M!1!BFj2Dc@bj;UIPV#(!J!;D}!K@GB3vP(J!nW-k|B6J952!0o=K*GU?1 zyRYeW;a|XQ8}wQ!DagDf#!VZ!pmB# zOZy-i;nM{lo$jdrdoXk4Czg#Zx+Bnwvhds5;gLWp{zb?{9vj*a`PrTXVKrbIMwI_> zvjcx2f5nfwf7HDTx7EE1&c$uDk)Z_j`Ot^BkoAc0QNTJk4;Dqg=V#9)2xIL!VhwEG z{JLUYQxmI_mu<%+tnJIG))?|URj6`z0XO)?2cgZ?K1`CPqxo_>Ykd|X^>tK z5;mFy&irA3e_G&!Q${>#V|8k?aK@g! zLvdC%qsYMMbpjh=PCbk*560P6k}UDm0ybRXmyURnz=1(S%iV{~V?*+XpUNy2a^S;? z)@7%n+3-t`7w4mK+V^Wg>BNo0aldcxL#?woP;4GF(A9+vdw0eLja!1#z(oU2uc=~z z$m>?Yh(r!tc9~q^e~1N92E8whn27Vxs;=$+T!vV1aEH((&jY7_kMD7jhy}Mdhf8xS zah8?e;d^Oc7Ffg-v@2T6h7MCE{(HiK1^wMm|LfV84K4@wb%-@#LB~OBFaLyr-^%7h zY&T_r-r_1wOg9z`UB2YU-&|R6Aid&PSIk@=iH>&OXcjmOd~P~Za#9Nhd;N2G%}N&h z{N3Dbm!31Bh|pmjtY$%u{fqFoZ}HarlfSWDG8-!IjY^w;hXG`cSI|8d4vZ@pw&l!8F0|@hkmz#qAAx`L;&odMbRai68}_VzEhr|R zcI_Rg3zp`2dJ?|#3qkzKxt>0E^&rkaedybxPXxn`eLLy*90Tx7>yf!@>s`UYqD5`a z3^oGs==MPwZ!QaV3>&8FKpI1iU;0Y_v|_Q>1tzfd)U~^|Q~oG^xy01O-Ixyr zi**wpgx@T_-C6vXV3ip>O4#yG*6}CB{43rU9tkaAy3UAkSA2-q#ZLOdr&pb=6A0flzip$VZcDOm=QqL!pJNf%xeFeg@%w9AK&r+0M~TA_W%F@ delta 14891 zcmdU02UwKH)4%sPiXcTmItYkLQ9-UDDl8%#2@&iuDk8|q@hU~&tmtVXDvF5i0g4S2 z1?+li(AX1#Jt#KR*kVj9iG>&=CiC7Di}g~=&OA0^TKM?*xFm6GLUK%eTzFLcq)8EpNm7BIap)ZUr`C9igZ&)r4AcDE zFO3iOU*};Ip!4~%K#a9}QUw@kca&@fQWb0O5eEFh-B+3{oh23RNJ~nV4w0toC&WaA zN5sSm!=n>pqNJh%XOXL@M=&E*Dso>Sbq=P2S?~xvjL3Kf!!nSOGC{zYcTb?sryb#w z91%H5SdWKO114-)yTG73M*?*pQ#DLvn{ZhL z?8p>7$0zkXsOz$7z3CKn8;9tb*Ytm->i-7%TbMi*C5oPj#GvI=_&0rF5)8hj?+$|v z4b%qP7#UFvM(uFABZYUxTAGR)V^D3#*Q^wM72E@l+iWWP)M!1WaXXle2kCWGHXmDV zVc1QLojdk0wr*6KvBoYS%|hdP()g(j4{S!7BvURhC^oIfV5Ay@ux1!Y%{bI`Y-rF? zcGE1XE}%9>@oG$xn_;rYq5zC6)Y_aF*|?1tMHP${c7crJIf0B4Zvo?4serM{KWNU$ zb3u%?KLjc0Yk_VqV3ZsXFkGF3=2VJ<(BX<8^zgS9>FNe&5SbV;exfirDO{Y8pealQ zv#XV)1u2FyRx~R`i}ZSu#6~1W$HYZU3Qvqk))XcwC7DPtAubbir@G%8e{Uaw#tI(~ z*J2;Nz)~O2;pIMS$3F0RU%FBf+{?yBu@bej5)AV0WCTvIa}?Mw$Q8IU41#!X(t>F+r>&>|Brgm3jlWsBr%wE0|kX<3Bf0Qs$ z*l5^R2r6gHVEE>){K6L`Q@bk}%QWpJ6jPsxBZ@47O_>ZvS{SULIXuu7}1q%o29A-6i=pZ1mccg|$8;-DS27#pKWYJ;aUNz|8%x7={lzM^WpsgG*P_EigO zEec`>lL;i;!{5q zQKhn{0il#r7O)FmP_qQ}5uOTEIs?xgX(5{_h-F#Tngy(j!UN_I4?y%(OoiE*q!8H< zodGvFa)1xsDU?!K7>!%?Zc~B6MoENCt-{3`cxeW4c_;kqK)2>au*HTh9hkE;ut)zc zJ>~-&c?d{@2YL~?EQSk-kMre`@qOeZkhGuX&FZ-kSgHVQ%IMT!x&lK!08Bp+HnlKS z5tytjzbB1}bQ!tTb?Xy2e1vBk0Z%q}j*fPxNRTka_D+Vf?a?$vZN%%ATPYjQE!Q(FP~*52 z?~Ju>#+AL*Loq20868$*^lb}_zMkkw72?s@RW=I`qBNv*eYq~FvAfm+yL++x z=B>oeS`jAGb(7&NB1{e!#Ux1^hRZJR5Sgo3%OlKX*Wz1e2+C+(1?shSdU2_F+>e0(3lmPc5~^r!W1Rh0Pv)5~#j(X1P0`nQ0~pxHhx*o}!Y{JkG8 z=CXh}t+SU8FufcXp&Bl;Tf~J#Uze7!QAV}*#b)-x)?xDjrk7)ruZGQm7O**<0ijFw zBBNmpH2P7E(OpeAHZ_+VJjo=S9G5l1=?^tduUg_{gMEP0XxxyKp&lgus(sf8J7hH% zD28XbH^1orZL*i1(>lv=B}%`4PVA_LU7%JAIZOGPwvz10a~0jZ)`2(q!+;l~2HujE zzzf+FtUN`d*&6Z9yktKe-#)L^9#=1q_TR;0Pwb}2H*_+9@UKV*#hp|U->nAi~ zXxEJZn%eS8P?P$~Y&V}&W_PU4Fj8I*a!_O8s@+0hi`v?c>LOdewI&>?81+?S;*?allM|)|?CvbRu58%YiDzw@Rl8Vj8uq>usi3KUTa3gJ?ZGDAU`S3@C3^IN5r zAld%!Z!6?;eP{V#-Jq&8l(4KlMs2`Zhx<_{aRpQnb7e0NCu>R8K?jw-O&yIttWwIt z1eB7_;D@E>-a|kCay9)fkvB!%ic=Txxm4RXD=~PA8O~1xoV8uy+QB`*9qF!d}1mS_UR|`(0gr(}=qs9Jp z3tO0~_!6q4OV6gk)!<^9q82Bvk*0p7NM1oz)Z%2P&mGrL6ys{Ux3Kt*I)GfG1=MZ| zN+ZQMu7XlmX4T@0ar=s?trH)UnwvFYUTfRK$(v&B#r4;|>6)AQrC|LxU=jj^hw7TH zVQt5AuKSoNiiEIKEMO`?%kK;P3*2Ah^OqbYwIm%mX=;(Yfo@0Md{1aCDLGKqf!3AN ziZjo^-&tCZyw#3tFYmct^B}b$&YjjLZ+ouAzT}NH&)oAJs2dIWsN~9 zK04FpFoL!qmzj9pWe=rB>%TGPKRoa=Xyj(C0YsGy3bj@qyD;jTnBd0Ll9o&$e@eAZ zJC^@E&M!@=)*+A2K8jjk7|p0cOpZwu;we8`wK13*^+OfR%WpP>xy}2b3hy1aYVb~< z{cHEsWs@7*q?QdZ4H^Ejps{E^WgLW7h){}6s!aILqHFtKD5Lp?3O4!smzz|5yhVEO zlLhjX$w1V>WiAexqOD4kWY=$Zq!f8lr<%;9exBs$)W|2vrqd+&@ZInWrds#983(q0 zK=PxF4>IA(dzh)nl2U49X)a6JcCT8QCDnFsAxkO(Pj5U++NPj1k|lkkpwyLFl_jO% zt`BCnPJDTmluww~W=Yq{n<7g}dG{)?%OMHik8z*I*v=+_Rxm))!srl#Po1 zF}$*uTy#~kL#&)OP@0QE^7?vYR|!-b9Cizc#smKtMwRB+Yf%%h&WoiIFBl#?s;Mo; z8mM$HxCrGHAhM+hoU_Zx2?+k~?=CFOHD#k|0Z)%)-S9-?j#^A!GlGpQWFuu%)HHqUq;9 z-MZo!4^_U6yB@8>L&b-E{-!VS&%k^W=#bJ`yu?b*%s7p)BFT@nQ4$XoNwJ-D}MGXwIzUho{!@(T?z) z*88}8bj>TWhtDBCaxYjF}E z9mqQ@3~~2DM}Jw+;rL8GS{OLhsLyyW)HymLv`-2j6~AT_@bbM-LD`eoYUsc4$_u?I zkGv3(Q#+E6=mkf*P8`r1?T~_G0eocl>Z`fGZ10Uy9}O0SR_QorFp ze(%sgBHY`d8}9Lru%Mih|X0fIt}2W z>}jY{`iP4vj#c|Eh~c84V}}13=DxvkwbaNcI(@u6<>4FtYzid zNppGVUN4QTXScZMap@qCk%v~D;kY9o6)fp|`R-RdRQ=ndqrKVuW*krQaXjyx9Z&LcJnx+yP11zxzB7Y& z9WSMWz|vKlk)=Geuc{q7GYi|nkE=~TA&0>2z(M{|XBRxsyWvv*x7GWg zO|i|Yep{UnZK|knEax+d>Ze@^jqvRbUwO;Ld5(I@S^E9kS;%IdwGnuczpO2kwKjbS z-ry(IA4K%qYN9_Uq63~+NNXZ}mE7v$S(&qlqWqPrxW2n8S^1Qy#WQ_g(%l0FC^F~k zWeI;h@s)m1MGSQ7u7Q1y1Bghe^#jA2nuIS-+GjzE*@mh(NlS{fvi z|8pgbi;ABbD~wB)3efg#KLrg!ba7- z(l=DY8#1%+MYh8s;~VSyot7SKNN&a^7oRTUpxIaQ+z-v=piyhJr$%4mAf7beb<=td z3Lg(!ni&`EEm~Od$#D*v*gYaPeK;3AbM^GLzr#UW-~14981@lcQ;Bqa)>{s$IJ%U+ zTE<1@0fs5htdk8Xx8 z$%%sjqi?+mXqU}L`3Ls@b^S6A-R@y`GpU4+QqPv{&A-n>jGia5f4;&;uZ^Ewe)X1z z3XPV2ce;v?9u0>ivkM>fDBtUzr{{&vdT-9PAH_#5e{Fgo8|sBlSsmPQsvTc~o=Ntb z^(gQ{H_i^VO_cG_EA5e6ru^fD?0Wg>-vZ5Fc>K2Nb$D-d&|z3{Wfd3st@`Ok*E7A* z#gsADu8X)R^_Pv2e}mJ&okA~zgCTv;sGY;i zpE`0-$FQ3=+kftZ_Kh6;Rm^ICHX1&q+AnNjU*vsgNVbCk8#S@v{73tZkw7i3jMK$_ Q)s~%`4v>n#A{O`jAKIc+BLDyZ From abc39cab2e8775fac27eb63b6ff3814893590e2e Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sun, 24 May 2026 15:45:09 +0200 Subject: [PATCH 12/21] Regen benchmark_params.pkl: revert health-prob extensions; validator filter handles unreachable per-target. Co-Authored-By: Claude Opus 4.7 --- .../_benchmark_data/benchmark_params.pkl | Bin 68534 -> 68534 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/src/aca_model/_benchmark_data/benchmark_params.pkl b/src/aca_model/_benchmark_data/benchmark_params.pkl index 4b492bbb72c2a5a5aeaed85a72001f938d2184b3..802ea586598d595bc32f138b9c6c30ebc8146ba2 100644 GIT binary patch delta 131 zcmV-}0DS+pmjt$#1hClum-7MvGzpcV3jhEB0001!T17aMP4q04bpb$^iUR?Nw><#? zwgHpa|1^`JA`6$z0RdE%wE>8i5(EKCm!Sj!CYS020UMXc1OX!%4s>a4WnX1(WN&wE lWo~qoM?kKv&lAvXgUOpbFknd-}!GrfzMv17X^f Date: Sun, 24 May 2026 18:15:46 +0200 Subject: [PATCH 13/21] benchmark_params.pkl: restore max_consumption_dollars=300000.0 in fixed_params. The pkl key is consumed by `get_benchmark_params` via `fixed_params.pop("max_consumption_dollars")` (benchmark.py:119) and threaded into `inject_consumption_dollars_points`. Earlier regenerations of the pkl dropped this entry, which broke both `tests/test_model_creation.py` and pylcm's `benchmark-pr` GPU run (the latter installs aca-model from this same pkl). Value (300000.0) matches `MAX_CONSUMPTION_DOLLARS` in `aca-data/src/aca_data/task_environment_constants.py:23`. Co-Authored-By: Claude Opus 4.7 --- .../_benchmark_data/benchmark_params.pkl | Bin 68534 -> 68025 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/src/aca_model/_benchmark_data/benchmark_params.pkl b/src/aca_model/_benchmark_data/benchmark_params.pkl index 802ea586598d595bc32f138b9c6c30ebc8146ba2..24c513835bfb15a47b358ab88b5677f38f534caf 100644 GIT binary patch delta 2213 zcmb_beN0nV6u+kw+EOS*+4==D7O7<~gc$6ku9PE9mpIH*rue)mlFH7HowA2c04E(#4rv`lpA9+BkT$cOeCM%~ zF3Bk8{hcD08cu8=6&%F8rN~;_(_f=Sop?a7oYvj89sab>Y{N>HKsd?7OAk2 zwB579qH#<%u_|=&W8Tc$wsb2O)Mj0k@VQlO@3%giKWOcE=Zf`VB+;;O>bvt=hh4s@prr}D>2f3D>U_PVoIsSYl@?AF%E*i%MvktRk zUJ4JJnLT4;A2B(0B^H`P!o!$@d1!PaUXDrhd6a$oGo39J8JJ{Fo5e8Ph}us13Vrv- z;I`=xIS3hSqLF$+MtlFWm=wFppeKX!8CEj-bb!U?*et!CVzc&s`bO3QD^15H_67@o z0wL1+A1f5HcqI-zuZNX5dD#G|(&XtsAi}65#92}x){7RWQ+Q4KeW)H*uoSmVPIn^K z(vb78yQWrrRTLhlXhFxeZBo-OadSx7y^t6)qLVJ34Fc_BEYqK3>YNF*yBB(6MvS2L z!%8pwssZ(f^DybH^Kwx8f<0gcp zQOdkGajZxmVxmW3we;0LYl37Z6*iW{sN$&0nJuhVVOL=kCfzE4G~9Ko(l1Nl%c6kX zr4c#?pSzgrFx#66o3X*02sWwLdr1a7?wL+}K=SGU@~xEIXcWqrvsw_CanVb^l|Ju( zTqA78tUEv9{$(0;zL?Ou9vUzseg(+T5x-p4qe91_sI{_>MBy;aa~Fsf;RrnlU(l_j zRxG9*1@4+1Vli#~DD}Cv)Bngpy)9_GrypX2tJvjw=3n!!yP*I8 delta 2663 zcmbtUeN0nV6o0oAT3`Z|CbX{~U=W55r3^&Wu=msq4 zzxzAqp4Y38f7#6ES9+Sj1^$ZyIPDWGGioIT!+I+G*%DG;tl<4U@&TFg{2Rxw6OAOS z5+q3^ceYjRJ6-$N4a~mmHtZI}uL*HDISjc)vRL`UmwONfjA@(g>1oTm=Evc8B{+x4*EI_EWd#HZH9RHV%9Vj(%a^Y~ zDp3AXUF5zZwQe(Kf$N7YL|!`4oAAz#_Y?EUvL+H86a+t)C&Tpm;Ao>(buP=s+4`29yWEC@ zbx*S428AkG#qGQ|NFrqA+D7r<$bvR3RuC#%7B)z*cMuaZY9E~iZKhcfSqCTN;TLbe zF5ra;;pr_BupZVvkb-!owsi#fD}@-Q+nY(3(-Ary^)f@Zm4mf`$Z6E_0n$bnp)Fm# zv9O3vqbVG1P%`#$h}m>EWG<5Tf03}IGl;3V=K%twY^E#6nGP8UsG09=>BNDrODr?h zw<_dgWy>THcU&o$<2 zSlcRRIv(@`{46@s*_?jXI-tu*!BP&w58I-c*j7}cnoEDPK*whkO( zipPS0(pis%t<1+8!;(-NxrN$W_nQnToh3dg=FIT#u0>nm__*AEJ`X+tz?OMDBDb^V zWgszgIt)z7D!p6Oh-SU5Q? z_Q%~4^w?WOXQS*) zV__AYjcUA#e1^BtH;{$AXeUDV8g&-=oRu}ggvSY>3zj|B+G)cVM+@`*6$ehPiYvVT zr5~AXZa9%=9Vcj=tH8v?G0hh^;3ja=+pOdPT-w2U3#d*OT)Wg|`V}NDUAp-Q03x27 z+-3H2m5*ybgUt#akS(mXAr%z^z%g@<2Z#{9i^jH`rbzQC53qBwb5uPgY1rrq+LkeY j+Y{Zy?qRJ1lnM3JkYzxBmf5;IYQaXMsGWvvgJI);{IwC2 From 46ceb032b7689bde8ccb874c115526c06ca3610b Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 25 May 2026 10:28:30 +0200 Subject: [PATCH 14/21] Drop private `_lcm.*` imports in favour of public `lcm.*` - `_common.py`: ContinuousGrid annotations replaced with concrete PiecewiseLinSpacedGrid (aime) and IrregSpacedGrid (consumption_dollars), matching the actual built types. - test_initial_conditions_extreme_assets: import validate_initial_conditions through `lcm.model` instead of `_lcm.simulation.initial_conditions`. --- src/aca_model/baseline/regimes/_common.py | 7 +++---- tests/test_initial_conditions_extreme_assets.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index b783ecb..c9cfaa8 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -11,7 +11,6 @@ import jax.numpy as jnp import numpy as np -from _lcm.grids.continuous import ContinuousGrid from lcm import ( DiscreteGrid, IrregSpacedGrid, @@ -191,8 +190,8 @@ class RegimeSpec(TypedDict): @dataclass(frozen=True) class Grids: assets: LinSpacedGrid - aime: ContinuousGrid - consumption_dollars: ContinuousGrid + aime: PiecewiseLinSpacedGrid + consumption_dollars: IrregSpacedGrid wage_res: Any hcc_persistent: Any hcc_transitory: Any @@ -298,7 +297,7 @@ def get_hcc_persistent_grid_points(*, grid_config: GridConfig) -> FloatND: def _build_aime_grid( *, grid_config: GridConfig, fixed_params: UserParams -) -> ContinuousGrid: +) -> PiecewiseLinSpacedGrid: """Return the AIME grid. The grid is piecewise-linspaced with breakpoints at the PIA bends diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index 17326fd..861c487 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -8,7 +8,7 @@ """ import jax.numpy as jnp -from _lcm.simulation.initial_conditions import validate_initial_conditions +from lcm.model import validate_initial_conditions from lcm import DiscreteGrid from aca_model.agent.assets_and_income import borrowing_constraint From 77ba8e44d4ba650c7cbc2c7f155dc73161684cdd Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 25 May 2026 12:24:50 +0200 Subject: [PATCH 15/21] Drop unused plotly + DataCatalog imports from aca_model.config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `import plotly.io` + `pio.templates.default = ...` and the unused module-level `DataCatalog()` instance contributed ~250ms of import time to every task collection that touched aca_model. Plotting defaults belong in aca-post, where the plotting actually happens. Brings `import aca_model.baseline.regimes` from ~1.6s to ~1.07s (the remaining cost is dominated by the `import lcm` beartype claw). 🤖 Generated with Claude Code Co-Authored-By: Claude Opus 4.7 --- src/aca_model/config.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/aca_model/config.py b/src/aca_model/config.py index e240841..5850a63 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -4,17 +4,10 @@ from pathlib import Path from types import MappingProxyType -import plotly.io as pio -from pytask import DataCatalog - SRC = Path(__file__).parent.resolve() ROOT = SRC.parents[1] BLD = ROOT / "bld" -data_catalog = DataCatalog() - -pio.templates.default = "plotly_dark+presentation" - @dataclass(frozen=True) class ModelConfig: From 032792146b59820c3804d06b434b5960314f7865 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 25 May 2026 16:18:07 +0200 Subject: [PATCH 16/21] Add n_wage_res_batch_size: splay wage_res shock axis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wires `RouwenhorstAR1Process.batch_size` through `GridConfig` so the wage-residual stochastic productmap can be split with an inner Python loop. At `n_wage_res_batch_size=1` the per-target Q_and_F intermediate shrinks by `n_wage_res_gridpoints` (5), bringing the kernel under 80 GB for the ACA-overlay nongroup_nomc_* regimes where the unsplayed kernel hit 144 GB. `n_pref_type_batch_size` remains a no-op pending separate splay wiring. 🤖 Generated with Claude Code Co-Authored-By: Claude Opus 4.7 --- src/aca_model/baseline/regimes/_common.py | 1 + src/aca_model/config.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index c9cfaa8..896d87c 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -242,6 +242,7 @@ def build_grids( rho=_WAGE_RHO, sigma=(1.0 - _WAGE_RHO**2) ** 0.5, mu=0.0, + batch_size=grid_config.n_wage_res_batch_size, ) hcc_persistent = get_hcc_persistent_shock(grid_config=grid_config) hcc_transitory = NormalIIDProcess( diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 5850a63..00063dd 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -42,6 +42,14 @@ class GridConfig: # all pref-types. Defaults to `0` — the production overrides set it # to `1` on hardware where the unsplayed kernel doesn't fit. n_pref_type_batch_size: int = 0 + # `batch_size` on the `wage_res` stochastic shock process: chunked + # productmap stride along the wage-residual stoch axis inside Q_and_F. + # `1` shrinks the per-target Q intermediate by `n_wage_res_gridpoints` + # at the cost of an inner Python loop; `0` lets the productmap span + # the full axis. Defaults to `0` — production overrides set it to `1` + # on hardware where the ACA-overlay per-cell DAG blows the kernel's + # compile-time working set past device HBM. + n_wage_res_batch_size: int = 0 # Per-device chunk size for the simulate-side per-subject dispatch, # keyed by `log_level`. Empty → 0 (no chunking) for every level. # `log_level="off"` skips `validate_V` and its forced host-sync, which From 648686a7db4820790c1ee00b670fcab8d6cced88 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 25 May 2026 17:47:37 +0200 Subject: [PATCH 17/21] GridConfig: add assets_distributed + pref_type_distributed flags MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Threads the existing pylcm `distributed=True` semantics out of hardcoded values into GridConfig so production runs can swap which axis is sharded. Defaults preserve the assets-sharded layout. The production override pairs `assets_distributed=False + pref_type_distributed=True` for the 3-device pref_type-sharded mesh: zero cross-shard collectives (pref_type is a fixed state with no transitions) and 1/n_pref_types per-device V_arr reduction. 🤖 Generated with Claude Code Co-Authored-By: Claude Opus 4.7 --- src/aca_model/baseline/regimes/_common.py | 2 +- src/aca_model/config.py | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 896d87c..1d4b391 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -262,7 +262,7 @@ def build_grids( stop=500_000.0, n_points=grid_config.n_assets_gridpoints, batch_size=grid_config.n_assets_batch_size, - distributed=True, + distributed=grid_config.assets_distributed, ), aime=_build_aime_grid(grid_config=grid_config, fixed_params=fixed_params), consumption_dollars=IrregSpacedGrid( diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 00063dd..9e310e3 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -28,13 +28,21 @@ class GridConfig: n_hcc_persistent_gridpoints: int = 3 n_hcc_transitory_gridpoints: int = 5 # `batch_size` on the assets / AIME grids: chunked vmap stride for the - # outer state loop. The assets axis is hardcoded `distributed=True` - # in regimes, so `n_assets_batch_size` must stay `0` — `>0 + distributed` - # is rejected by pylcm's grid-init guard. `n_aime_batch_size` is free - # to splay; `1` shrinks the per-period Q intermediate by 12x on hosts - # where the unsplayed kernel doesn't fit. + # outer state loop. `n_assets_batch_size > 0` is rejected by pylcm's + # grid-init guard when `assets_distributed=True` (pylcm forbids + # batching a distributed axis); set both consistently. `n_aime_batch_size` + # is independent; `1` shrinks the per-period Q intermediate by 12x on + # hosts where the unsplayed kernel doesn't fit. n_assets_batch_size: int = 0 n_aime_batch_size: int = 1 + # Sharding flags for the assets and pref_type grids: pylcm distributes + # the grid across devices when `distributed=True`. Only one strategy + # can be active per run (pylcm's multi-grid sharding requires + # `prod(grid_sizes) == n_devices`, so combining is infeasible at the + # production state-grid sizes). The default keeps the legacy assets- + # sharded layout; the production override flips to pref_type-sharded. + assets_distributed: bool = True + pref_type_distributed: bool = False # `batch_size` on the `pref_type` discrete grid: chunked vmap stride # for the pref-type axis during solve. `1` (one pref-type per Python # dispatch) shrinks the per-period Q intermediate by `n_pref_types` From 522b22cb08105e0dc11a328364515b44dba9404c Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Mon, 25 May 2026 18:05:16 +0200 Subject: [PATCH 18/21] GridConfig: add per-discrete-state batch_size flags MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Plumbs `n_health_batch_size`, `n_spousal_income_batch_size`, `n_lagged_labor_supply_batch_size`, `n_claimed_ss_batch_size` through GridConfig and into `build_states` via `Grids.grid_config`. The flags let production runs splay every non-sharded discrete state with a Python-level outer loop, compounding the per-call Q intermediate reduction by the product of the splayed axes' cardinalities — outer to the continuous block per the `_ordered_state_action_names` sort. 🤖 Generated with Claude Code Co-Authored-By: Claude Opus 4.7 --- src/aca_model/baseline/regimes/_common.py | 22 ++++++++++++++++++---- src/aca_model/config.py | 13 +++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 1d4b391..55bedbb 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -196,6 +196,11 @@ class Grids: hcc_persistent: Any hcc_transitory: Any pref_type: DiscreteGrid + grid_config: GridConfig + """The originating `GridConfig`. Exposed on `Grids` so `build_states` + can read per-axis `batch_size` settings for the discrete states it + constructs inline (health, spousal_income, lagged_labor_supply, + claimed_ss) without changing the `build_states`/`build_regime` API.""" # AIME piecewise grid: number of points per segment between the PIA @@ -272,6 +277,7 @@ def build_grids( hcc_persistent=hcc_persistent, hcc_transitory=hcc_transitory, pref_type=pref_type_grid, + grid_config=grid_config, ) @@ -392,23 +398,31 @@ def make_active_func(spec: RegimeSpec) -> Callable[..., Any]: def build_states(spec: RegimeSpec, grids: Grids) -> dict: """Build the state dict for a non-dead regime.""" can_work = spec["canwork"] == "canwork" + gc = grids.grid_config states: dict = {} states["assets"] = grids.assets states["aime"] = grids.aime states["health"] = DiscreteGrid( - Health if spec["mc"] == "oamc" else HealthWithDisability + Health if spec["mc"] == "oamc" else HealthWithDisability, + batch_size=gc.n_health_batch_size, ) states["hcc_persistent"] = grids.hcc_persistent states["hcc_transitory"] = grids.hcc_transitory - states["spousal_income"] = DiscreteGrid(SpousalIncome) + states["spousal_income"] = DiscreteGrid( + SpousalIncome, batch_size=gc.n_spousal_income_batch_size + ) states["pref_type"] = grids.pref_type if can_work: states["log_ft_wage_res"] = grids.wage_res if can_work and spec["his"] != "tied": - states["lagged_labor_supply"] = DiscreteGrid(LaggedLaborSupply) + states["lagged_labor_supply"] = DiscreteGrid( + LaggedLaborSupply, batch_size=gc.n_lagged_labor_supply_batch_size + ) if spec["ss"] == "choose": - states["claimed_ss"] = DiscreteGrid(ClaimedSS) + states["claimed_ss"] = DiscreteGrid( + ClaimedSS, batch_size=gc.n_claimed_ss_batch_size + ) return states diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 9e310e3..100fb6a 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -43,6 +43,19 @@ class GridConfig: # sharded layout; the production override flips to pref_type-sharded. assets_distributed: bool = True pref_type_distributed: bool = False + # `batch_size` on the inline-constructed discrete state grids + # (health, spousal_income, lagged_labor_supply, claimed_ss). These + # are read in `build_states` via `grids.grid_config`. `1` puts each + # axis in a Python-level outer loop within the discrete-states block + # of the productmap (`_ordered_state_action_names`), shrinking the + # per-call Q intermediate by that axis's cardinality at the cost of + # one extra lax.scan layer. Defaults to `0`; production overrides + # set to `1` to compound the splay across all discretes outside + # the sharded one. + n_health_batch_size: int = 0 + n_spousal_income_batch_size: int = 0 + n_lagged_labor_supply_batch_size: int = 0 + n_claimed_ss_batch_size: int = 0 # `batch_size` on the `pref_type` discrete grid: chunked vmap stride # for the pref-type axis during solve. `1` (one pref-type per Python # dispatch) shrinks the per-period Q intermediate by `n_pref_types` From 09de8ab502f7dc2122333ba735f30ad576a2f140 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Tue, 26 May 2026 12:36:01 +0200 Subject: [PATCH 19/21] Drop assets_distributed from GridConfig Continuous-grid sharding is rejected at pylcm grid init: every interpolation lookup over a sharded continuous axis compiles to an `all-gather` of the full V-array per device. The `assets` axis therefore cannot be sharded. The `assets_distributed` knob and its consumer in `build_grids` are removed; runs select `pref_type_distributed` (the remaining discrete-sharding option) or fall back to single-device splaying. --- src/aca_model/baseline/regimes/_common.py | 1 - src/aca_model/config.py | 37 ++++++++++------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 55bedbb..143fd54 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -267,7 +267,6 @@ def build_grids( stop=500_000.0, n_points=grid_config.n_assets_gridpoints, batch_size=grid_config.n_assets_batch_size, - distributed=grid_config.assets_distributed, ), aime=_build_aime_grid(grid_config=grid_config, fixed_params=fixed_params), consumption_dollars=IrregSpacedGrid( diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 100fb6a..f76df07 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -28,30 +28,25 @@ class GridConfig: n_hcc_persistent_gridpoints: int = 3 n_hcc_transitory_gridpoints: int = 5 # `batch_size` on the assets / AIME grids: chunked vmap stride for the - # outer state loop. `n_assets_batch_size > 0` is rejected by pylcm's - # grid-init guard when `assets_distributed=True` (pylcm forbids - # batching a distributed axis); set both consistently. `n_aime_batch_size` - # is independent; `1` shrinks the per-period Q intermediate by 12x on - # hosts where the unsplayed kernel doesn't fit. + # outer state loop. `n_aime_batch_size=1` shrinks the per-period Q + # intermediate by 12x on hosts where the unsplayed kernel doesn't fit. n_assets_batch_size: int = 0 n_aime_batch_size: int = 1 - # Sharding flags for the assets and pref_type grids: pylcm distributes - # the grid across devices when `distributed=True`. Only one strategy - # can be active per run (pylcm's multi-grid sharding requires - # `prod(grid_sizes) == n_devices`, so combining is infeasible at the - # production state-grid sizes). The default keeps the legacy assets- - # sharded layout; the production override flips to pref_type-sharded. - assets_distributed: bool = True + # Sharding flag for the `pref_type` discrete grid: pylcm distributes + # the grid across devices when `distributed=True`. Sharding is only + # supported on discrete state grids; continuous axes (`assets`, + # `aime`, `wage_res`, `hcc_*`) compile to an all-gather of the full + # V-array per device and are rejected at grid construction. pref_type_distributed: bool = False - # `batch_size` on the inline-constructed discrete state grids - # (health, spousal_income, lagged_labor_supply, claimed_ss). These - # are read in `build_states` via `grids.grid_config`. `1` puts each - # axis in a Python-level outer loop within the discrete-states block - # of the productmap (`_ordered_state_action_names`), shrinking the - # per-call Q intermediate by that axis's cardinality at the cost of - # one extra lax.scan layer. Defaults to `0`; production overrides - # set to `1` to compound the splay across all discretes outside - # the sharded one. + # `batch_size` on the inline-constructed discrete state grids — + # health, spousal_income, lagged_labor_supply, claimed_ss. These + # are read in `build_states` via `grids.grid_config`. Setting any + # of them to `1` puts that axis in a Python-level outer loop within + # the discrete-states block of the productmap + # (`_ordered_state_action_names`), shrinking the per-call Q + # intermediate by that axis's cardinality at the cost of one extra + # lax.scan layer. Defaults to `0`; production overrides set to `1` + # to compound the splay across the unsharded discretes. n_health_batch_size: int = 0 n_spousal_income_batch_size: int = 0 n_lagged_labor_supply_batch_size: int = 0 From 7dbb34488adb49619e2ab6e28d9a855b1bc557f2 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Tue, 26 May 2026 12:56:37 +0200 Subject: [PATCH 20/21] Drop subjects_batch_size from create_model and GridConfig MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `subjects_batch_size` knob was added alongside a pylcm change that has since been reverted; current pylcm `Model.__init__` does not accept it, so every `create_model` call was raising `TypeError` against unmodified pylcm. Drop the kwarg from both `baseline.create_model` and `aca.create_model`; drop the matching `subjects_batch_size_by_log_level` field and `get_subjects_batch_size` helper from `GridConfig`. If per-subject chunking turns out to be needed later, it should re-land together with the corresponding pylcm Model parameter — wired through a single source of truth, not split across two repos. --- src/aca_model/aca/model.py | 6 ------ src/aca_model/baseline/model.py | 6 ------ src/aca_model/config.py | 19 +------------------ 3 files changed, 1 insertion(+), 30 deletions(-) diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 22437dd..6d39ac6 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -25,7 +25,6 @@ def create_model( derived_categoricals: Mapping[str, DiscreteGrid], grid_config: GridConfig, pref_type_grid: DiscreteGrid, - subjects_batch_size: int = 0, ) -> Model: """Create an ACA policy variant model. @@ -46,10 +45,6 @@ def create_model( `pref_type`. grid_config: Continuous-grid point counts. pref_type_grid: Pref-type `DiscreteGrid`. - subjects_batch_size: Per-device chunk size for the simulate-side - per-subject dispatch. `0` (default) keeps a single vmap over - all subjects; `>0` chunks each device's local shard via - `jax.lax.map`. Tune via `grid_config.get_subjects_batch_size(log_level)`. Returns: pylcm Model. @@ -76,5 +71,4 @@ def create_model( fixed_params=fixed_params, derived_categoricals=derived_categoricals, n_subjects=n_subjects, - subjects_batch_size=subjects_batch_size, ) diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index 85fd8d9..98416ce 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -28,7 +28,6 @@ def create_model( derived_categoricals: Mapping[str, DiscreteGrid], grid_config: GridConfig, pref_type_grid: DiscreteGrid, - subjects_batch_size: int = 0, ) -> Model: """Create the baseline structural retirement model. @@ -53,10 +52,6 @@ def create_model( pref_type_grid: Pref-type `DiscreteGrid`. Pass `DiscreteGrid(PrefType)` for the production 3-type layout, or a compact variant (e.g. `DiscreteGrid(BenchmarkPrefType)`). - subjects_batch_size: Per-device chunk size for the simulate-side - per-subject dispatch. `0` (default) keeps a single vmap over - all subjects; `>0` chunks each device's local shard via - `jax.lax.map`. Tune via `grid_config.get_subjects_batch_size(log_level)`. Returns: A pylcm Model with 19 regimes (18 non-terminal + dead) spanning @@ -83,5 +78,4 @@ def create_model( fixed_params=fixed_params, derived_categoricals=derived_categoricals, n_subjects=n_subjects, - subjects_batch_size=subjects_batch_size, ) diff --git a/src/aca_model/config.py b/src/aca_model/config.py index f76df07..a3186ae 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -1,8 +1,7 @@ """Configuration for the aca_model package.""" -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path -from types import MappingProxyType SRC = Path(__file__).parent.resolve() ROOT = SRC.parents[1] @@ -66,22 +65,6 @@ class GridConfig: # on hardware where the ACA-overlay per-cell DAG blows the kernel's # compile-time working set past device HBM. n_wage_res_batch_size: int = 0 - # Per-device chunk size for the simulate-side per-subject dispatch, - # keyed by `log_level`. Empty → 0 (no chunking) for every level. - # `log_level="off"` skips `validate_V` and its forced host-sync, which - # lets XLA pipeline across periods and reuse scratch — affordable - # chunk size grows. Use `get_subjects_batch_size(log_level)`. - subjects_batch_size_by_log_level: MappingProxyType[str, int] = field( - default_factory=lambda: MappingProxyType({}) - ) - - def get_subjects_batch_size(self, log_level: str) -> int: - """Return the per-device simulate chunk size for `log_level`. - - Returns 0 (no chunking) when this `GridConfig` defines no entry for - the given log level. - """ - return self.subjects_batch_size_by_log_level.get(log_level, 0) MODEL_CONFIG = ModelConfig() From 7af96820b1289c71adc70ea3d2d6320c646703d0 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Tue, 26 May 2026 16:38:42 +0200 Subject: [PATCH 21/21] Extend GridConfig with distributed flags for non-pref_type discrete states Add `lagged_labor_supply_distributed`, `claimed_ss_distributed`, and `spousal_income_distributed` to `GridConfig` and thread them through `build_states` into the inline-built `DiscreteGrid(...)` calls. Enables the 2x2 (lagged_labor_supply x claimed_ss) and 3-way (spousal_income) sharding configurations needed by the OOM / performance experiment matrix on Marvin. Co-Authored-By: Claude Opus 4.7 --- src/aca_model/baseline/regimes/_common.py | 12 ++++++-- src/aca_model/config.py | 17 +++++++---- tests/test_model_creation.py | 36 +++++++++++++++++++++++ 3 files changed, 57 insertions(+), 8 deletions(-) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index 143fd54..820f487 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -409,18 +409,24 @@ def build_states(spec: RegimeSpec, grids: Grids) -> dict: states["hcc_persistent"] = grids.hcc_persistent states["hcc_transitory"] = grids.hcc_transitory states["spousal_income"] = DiscreteGrid( - SpousalIncome, batch_size=gc.n_spousal_income_batch_size + SpousalIncome, + batch_size=gc.n_spousal_income_batch_size, + distributed=gc.spousal_income_distributed, ) states["pref_type"] = grids.pref_type if can_work: states["log_ft_wage_res"] = grids.wage_res if can_work and spec["his"] != "tied": states["lagged_labor_supply"] = DiscreteGrid( - LaggedLaborSupply, batch_size=gc.n_lagged_labor_supply_batch_size + LaggedLaborSupply, + batch_size=gc.n_lagged_labor_supply_batch_size, + distributed=gc.lagged_labor_supply_distributed, ) if spec["ss"] == "choose": states["claimed_ss"] = DiscreteGrid( - ClaimedSS, batch_size=gc.n_claimed_ss_batch_size + ClaimedSS, + batch_size=gc.n_claimed_ss_batch_size, + distributed=gc.claimed_ss_distributed, ) return states diff --git a/src/aca_model/config.py b/src/aca_model/config.py index a3186ae..d75e7ed 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -31,12 +31,19 @@ class GridConfig: # intermediate by 12x on hosts where the unsplayed kernel doesn't fit. n_assets_batch_size: int = 0 n_aime_batch_size: int = 1 - # Sharding flag for the `pref_type` discrete grid: pylcm distributes - # the grid across devices when `distributed=True`. Sharding is only - # supported on discrete state grids; continuous axes (`assets`, - # `aime`, `wage_res`, `hcc_*`) compile to an all-gather of the full - # V-array per device and are rejected at grid construction. + # Sharding flags for discrete state grids. pylcm distributes the + # grid across available devices when the flag is `True`. Sharding + # is only supported on discrete state grids; continuous axes + # (`assets`, `aime`, `wage_res`, `hcc_*`) compile to an all-gather + # of the full V-array per device and are rejected at grid + # construction. Mutually exclusive with `batch_size>0` on the same + # axis (pylcm rejects the combination). The non-`pref_type` flags + # route through `baseline/regimes/_common.py:build_states` to the + # inline-built `DiscreteGrid(...)` calls. pref_type_distributed: bool = False + lagged_labor_supply_distributed: bool = False + claimed_ss_distributed: bool = False + spousal_income_distributed: bool = False # `batch_size` on the inline-constructed discrete state grids — # health, spousal_income, lagged_labor_supply, claimed_ss. These # are read in `build_states` via `grids.grid_config`. Setting any diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index d19758c..1d89233 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -1,6 +1,7 @@ """Tests for baseline model creation and regime structure.""" from collections.abc import Mapping +from dataclasses import replace import pytest from helpers.model import make_aca_model, make_baseline_model @@ -274,3 +275,38 @@ def test_baseline_model_creates() -> None: """Baseline model creates successfully without PolicyVariant.""" model = make_baseline_model(n_subjects=1) assert len(model.user_regimes) == 19 + + +@pytest.mark.parametrize( + ("config_field", "state_name"), + [ + ("lagged_labor_supply_distributed", "lagged_labor_supply"), + ("claimed_ss_distributed", "claimed_ss"), + ("spousal_income_distributed", "spousal_income"), + ], +) +def test_discrete_state_distributed_flag_propagates_to_regime( + config_field: str, state_name: str +) -> None: + """`GridConfig._distributed=True` sets `distributed=True` on the + `DiscreteGrid` for that axis in every regime that carries it.""" + gc = replace(BENCHMARK_GRID_CONFIG, **{config_field: True}) + grids = build_grids( + grid_config=gc, + fixed_params=_FIXED_PARAMS, + wage_params=_WAGE_PARAMS, + pref_type_grid=DiscreteGrid(BenchmarkPrefType), + ) + regime = _build_regime("retiree_dimc_choose_canwork", grids) + assert regime.states[state_name].distributed is True + + +@pytest.mark.parametrize( + "state_name", + ["lagged_labor_supply", "claimed_ss", "spousal_income"], +) +def test_discrete_state_distributed_flag_defaults_to_false(state_name: str) -> None: + """`distributed` on inline-built discrete states defaults to `False` so + configurations that do not opt in see no behaviour change.""" + regime = build_regime("retiree_dimc_choose_canwork") + assert regime.states[state_name].distributed is False