diff --git a/changelog.d/951.added b/changelog.d/951.added new file mode 100644 index 000000000..13dde67e6 --- /dev/null +++ b/changelog.d/951.added @@ -0,0 +1 @@ +Add worker-scoped setup and validation context contracts for local H5 builds. diff --git a/modal_app/fixtures/h5_cases.py b/modal_app/fixtures/h5_cases.py index cf06e723f..aaae1c5fe 100644 --- a/modal_app/fixtures/h5_cases.py +++ b/modal_app/fixtures/h5_cases.py @@ -9,7 +9,11 @@ import sqlite3 from dataclasses import dataclass from pathlib import Path -from typing import Any + +from policyengine_us_data.build_outputs.worker_inputs import ( + WorkerCalibrationInputPayload, + WorkerCalibrationInputs, +) FIXTURE_DATASET_PATH = Path( "/root/policyengine-us-data/tests/integration/test_fixture_50hh.h5" @@ -28,7 +32,7 @@ class SeededCase: """Description of one tiny end-to-end H5 test case.""" name: str - calibration_inputs: dict[str, Any] + calibration_inputs: WorkerCalibrationInputPayload expected_district_name: str = DISTRICT_NAME n_clones: int = N_CLONES seed: int = SEED @@ -192,13 +196,8 @@ def seed_case( weights_path = _write_weights(artifact_dir, n_records=n_records) db_path = _write_db(artifact_dir) - calibration_inputs: dict[str, Any] = { - "weights": str(weights_path), - "dataset": str(dataset_path), - "database": str(db_path), - "n_clones": N_CLONES, - "seed": SEED, - } + geography_path = None + package_path = None if case_name == "saved_geography_success": geography_path = _write_saved_geography(artifact_dir, n_records=n_records) @@ -207,11 +206,9 @@ def seed_case( weights_path=weights_path, geography_path=geography_path, ) - calibration_inputs["geography"] = str(geography_path) elif case_name == "package_fallback_success": package_path = _write_calibration_package(artifact_dir, n_records=n_records) _write_run_config(artifact_dir, weights_path=weights_path) - calibration_inputs["calibration_package"] = str(package_path) elif case_name == "misnamed_package": _write_misnamed_package(artifact_dir, n_records=n_records) _write_run_config(artifact_dir, weights_path=weights_path) @@ -220,5 +217,13 @@ def seed_case( return SeededCase( name=case_name, - calibration_inputs=calibration_inputs, + calibration_inputs=WorkerCalibrationInputs( + weights_path=weights_path, + dataset_path=dataset_path, + database_path=db_path, + geography_path=geography_path, + calibration_package_path=package_path, + n_clones=N_CLONES, + seed=SEED, + ).to_wire_dict(), ) diff --git a/modal_app/h5_test_harness.py b/modal_app/h5_test_harness.py index cdcf27335..225dc2166 100644 --- a/modal_app/h5_test_harness.py +++ b/modal_app/h5_test_harness.py @@ -19,6 +19,10 @@ from modal_app.images import cpu_image as image # noqa: E402 from modal_app.local_area import VOLUME_MOUNT, pipeline_volume, staging_volume # noqa: E402 +from policyengine_us_data.build_outputs.worker_inputs import ( # noqa: E402 + WorkerCalibrationInputPayload, + WorkerCalibrationInputs, +) app = modal.App( @@ -70,19 +74,16 @@ def _calibration_inputs( calibration_package_path: Path | None = None, n_clones: int = 1, seed: int = 42, -) -> dict: - inputs = { - "weights": str(weights_path), - "dataset": str(dataset_path), - "database": str(db_path), - "n_clones": n_clones, - "seed": seed, - } - if geography_path is not None: - inputs["geography"] = str(geography_path) - if calibration_package_path is not None: - inputs["calibration_package"] = str(calibration_package_path) - return inputs +) -> WorkerCalibrationInputPayload: + return WorkerCalibrationInputs( + weights_path=weights_path, + dataset_path=dataset_path, + database_path=db_path, + geography_path=geography_path, + calibration_package_path=calibration_package_path, + n_clones=n_clones, + seed=seed, + ).to_wire_dict() @app.function( @@ -290,17 +291,15 @@ def preflight_h5_case(run_id: str, *, n_clones: int = 1) -> dict: calibration_package_path=package_path if package_path.exists() else None, blocks_path=artifact_dir / "stacked_blocks.npy", ) - calibration_inputs = { - "weights": str(weights_path), - "dataset": str(dataset_path), - "database": str(db_path), - "n_clones": n_clones, - "seed": SEED, - } - if geography_path.exists(): - calibration_inputs["geography"] = str(geography_path) - if package_path.exists(): - calibration_inputs["calibration_package"] = str(package_path) + calibration_inputs = WorkerCalibrationInputs.from_artifact_paths( + weights_path=weights_path, + dataset_path=dataset_path, + database_path=db_path, + geography_path=geography_path, + calibration_package_path=package_path, + n_clones=n_clones, + seed=SEED, + ).to_wire_dict() return { "fingerprint": fingerprint, "geography_source": resolved.kind if resolved is not None else None, diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 2b36b3cd2..89ce90b09 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -17,7 +17,7 @@ import sys import traceback from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Mapping import modal @@ -39,6 +39,9 @@ from policyengine_us_data.build_outputs.partitioning import ( # noqa: E402 partition_weighted_work_items, ) +from policyengine_us_data.build_outputs.worker_inputs import ( # noqa: E402 + WorkerCalibrationInputs, +) from policyengine_us_data.pipeline_metadata import pipeline_node # noqa: E402 from policyengine_us_data.pipeline_schema import PipelineNode # noqa: E402 from policyengine_us_data.utils.run_context import resolve_run_id # noqa: E402 @@ -489,6 +492,31 @@ def _build_worker_bootstrap( return bundle +def _build_worker_calibration_inputs( + *, + weights_path: Path, + geography_path: Path, + dataset_path: Path, + db_path: Path, + n_clones: int, + seed: int, + run_config_path: Path | None = None, + calibration_package_path: Path | None = None, +) -> WorkerCalibrationInputs: + """Build the normalized H5 worker input payload.""" + + return WorkerCalibrationInputs.from_artifact_paths( + weights_path=weights_path, + geography_path=geography_path, + dataset_path=dataset_path, + database_path=db_path, + n_clones=n_clones, + seed=seed, + run_config_path=run_config_path, + calibration_package_path=calibration_package_path, + ) + + @pipeline_node( PipelineNode( id="coordinate_work_partition", @@ -561,9 +589,10 @@ def run_phase( completed: set, branch: str, run_id: str, - calibration_inputs: Dict[str, str], + calibration_inputs: WorkerCalibrationInputs | Mapping[str, object], run_dir: Path, validate: bool = True, + scope_fingerprint: str | None = None, ) -> tuple: """Run a single build phase, spawning workers and collecting results. @@ -575,6 +604,9 @@ def run_phase( """ work_chunks = partition_work(work_items, num_workers, completed) total_remaining = sum(len(c) for c in work_chunks) + worker_input_payload = WorkerCalibrationInputs.from_wire_dict( + calibration_inputs + ).to_wire_dict() print(f"\n--- Phase: {phase_name} ---") print(f"Remaining work: {total_remaining} items across {len(work_chunks)} workers") @@ -590,9 +622,11 @@ def run_phase( handle = build_areas_worker.spawn( branch=branch, run_id=run_id, + scope="regional", work_items=chunk, - calibration_inputs=calibration_inputs, + calibration_inputs=worker_input_payload, validate=validate, + scope_fingerprint=scope_fingerprint, ) print(f" → fc: {handle.object_id}") handles.append(handle) @@ -685,9 +719,11 @@ def run_phase( def build_areas_worker( branch: str, run_id: str, + scope: str, work_items: List[Dict], - calibration_inputs: Dict[str, str], + calibration_inputs: WorkerCalibrationInputs | Mapping[str, object], validate: bool = True, + scope_fingerprint: str | None = None, ) -> Dict: """ Worker function that builds a subset of H5 files. @@ -702,33 +738,24 @@ def build_areas_worker( output_dir.mkdir(parents=True, exist_ok=True) work_items_json = json.dumps(work_items) + worker_inputs = WorkerCalibrationInputs.from_wire_dict(calibration_inputs) worker_cmd = [ *_python_cmd("-m", "modal_app.worker_script"), "--work-items", work_items_json, - "--weights-path", - calibration_inputs["weights"], - "--dataset-path", - calibration_inputs["dataset"], - "--db-path", - calibration_inputs["database"], + *worker_inputs.to_worker_cli_args(), "--output-dir", str(output_dir), + "--scope", + scope, + "--run-id", + run_id, + "--artifacts-dir", + str(Path("/pipeline/artifacts") / run_id), ] - if "geography" in calibration_inputs: - worker_cmd.extend(["--geography-path", calibration_inputs["geography"]]) - if "calibration_package" in calibration_inputs: - worker_cmd.extend( - [ - "--calibration-package-path", - calibration_inputs["calibration_package"], - ] - ) - if "n_clones" in calibration_inputs: - worker_cmd.extend(["--n-clones", str(calibration_inputs["n_clones"])]) - if "seed" in calibration_inputs: - worker_cmd.extend(["--seed", str(calibration_inputs["seed"])]) + if scope_fingerprint: + worker_cmd.extend(["--scope-fingerprint", scope_fingerprint]) repo_root = Path("/root/policyengine-us-data") cal_dir = repo_root / "policyengine_us_data" / "calibration" worker_cmd.extend( @@ -753,8 +780,10 @@ def build_areas_worker( env=os.environ.copy(), ) - if result.returncode != 0: + if result.stderr: print(f"Worker stderr:\n{result.stderr}", file=__import__("sys").stderr) + + if result.returncode != 0: return { "completed": [], "failed": [f"{item['type']}:{item['id']}" for item in work_items], @@ -1085,16 +1114,16 @@ def coordinate_publish( ) print("All required pipeline artifacts found on volume.") - calibration_inputs = { - "weights": str(weights_path), - "geography": str(geography_path), - "dataset": str(dataset_path), - "database": str(db_path), - "n_clones": n_clones, - "seed": 42, - } - if calibration_package_path.exists(): - calibration_inputs["calibration_package"] = str(calibration_package_path) + calibration_inputs = _build_worker_calibration_inputs( + weights_path=weights_path, + geography_path=geography_path, + dataset_path=dataset_path, + db_path=db_path, + n_clones=n_clones, + seed=42, + run_config_path=config_json_path, + calibration_package_path=calibration_package_path, + ) validate_artifacts(config_json_path, artifacts) if validate: @@ -1215,6 +1244,7 @@ def coordinate_publish( calibration_inputs=calibration_inputs, run_dir=run_dir, validate=validate, + scope_fingerprint=fingerprint, ) accumulated_errors = [] @@ -1396,14 +1426,15 @@ def coordinate_national_publish( ) print("All required national pipeline artifacts found.") - calibration_inputs = { - "weights": str(weights_path), - "geography": str(geography_path), - "dataset": str(dataset_path), - "database": str(db_path), - "n_clones": n_clones, - "seed": 42, - } + calibration_inputs = _build_worker_calibration_inputs( + weights_path=weights_path, + geography_path=geography_path, + dataset_path=dataset_path, + db_path=db_path, + n_clones=n_clones, + seed=42, + run_config_path=config_json_path, + ) validate_artifacts( config_json_path, artifacts, @@ -1444,9 +1475,11 @@ def coordinate_national_publish( worker_result = build_areas_worker.remote( branch=branch, run_id=run_id, + scope="national", work_items=work_items, - calibration_inputs=calibration_inputs, + calibration_inputs=calibration_inputs.to_wire_dict(), validate=validate, + scope_fingerprint=fingerprint, ) print( diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index 06f96f7e2..d89d22b1c 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -12,124 +12,6 @@ from pathlib import Path from typing import Any -import numpy as np - - -def _validate_in_subprocess( - h5_path, - area_type, - area_id, - display_id, - area_targets, - area_training, - constraints_map, - db_path, - period, -): - """Run validation for one area inside a subprocess. - - All Microsimulation memory is reclaimed when the - subprocess exits. - """ - import logging - - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s %(message)s", - ) - from policyengine_us import Microsimulation - from sqlalchemy import create_engine as _ce - from policyengine_us_data.calibration.validate_staging import ( - validate_area, - _build_variable_entity_map, - ) - - engine = _ce(f"sqlite:///{db_path}") - sim = Microsimulation(dataset=h5_path) - variable_entity_map = _build_variable_entity_map(sim) - - results = validate_area( - sim=sim, - targets_df=area_targets, - engine=engine, - area_type=area_type, - area_id=area_id, - display_id=display_id, - dataset_path=h5_path, - period=period, - training_mask=area_training, - variable_entity_map=variable_entity_map, - constraints_map=constraints_map, - ) - return results - - -def _validate_h5_subprocess( - h5_path, - request, - validation_targets, - training_mask_full, - constraints_map, - db_path, - period, -): - """Spawn a subprocess to validate one H5 file. - - Uses multiprocessing spawn to isolate memory. - """ - import multiprocessing as _mp - - geo_level = request.validation_geo_level - geographic_ids = tuple(str(item) for item in request.validation_geographic_ids) - if geo_level is None: - return [] - area_type = { - "state": "states", - "district": "districts", - "city": "cities", - "national": "national", - }.get(request.area_type) - if area_type is None: - return [] - display_id = request.display_name - - # Filter targets to matching area - if request.area_type == "national": - mask = validation_targets["geo_level"] == geo_level - else: - mask = (validation_targets["geo_level"] == geo_level) & ( - validation_targets["geographic_id"].astype(str).isin(geographic_ids) - ) - - area_targets = validation_targets[mask].reset_index(drop=True) - area_training = training_mask_full[mask.values] - - if len(area_targets) == 0: - return [] - - # Filter constraints_map to relevant strata - area_strata = area_targets["stratum_id"].unique().tolist() - area_constraints = {int(s): constraints_map.get(int(s), []) for s in area_strata} - - ctx = _mp.get_context("spawn") - with ctx.Pool(1) as pool: - results = pool.apply( - _validate_in_subprocess, - ( - h5_path, - area_type, - request.area_id, - display_id, - area_targets, - area_training, - area_constraints, - db_path, - period, - ), - ) - - return results - def parse_args(argv: list[str] | None = None): """Parse worker arguments for legacy and typed request inputs.""" @@ -149,6 +31,32 @@ def parse_args(argv: list[str] | None = None): parser.add_argument("--dataset-path", required=True) parser.add_argument("--db-path", required=True) parser.add_argument("--output-dir", required=True) + parser.add_argument( + "--scope", + choices=("regional", "national"), + required=True, + help="Worker bootstrap scope to use for this request batch", + ) + parser.add_argument( + "--run-id", + default=None, + help="Pipeline run ID used for traceability and bootstrap lookup", + ) + parser.add_argument( + "--artifacts-dir", + default=None, + help="Optional run-scoped pipeline artifacts directory containing bootstrap artifacts", + ) + parser.add_argument( + "--run-config-path", + default=None, + help="Optional unified run configuration JSON used for traceability", + ) + parser.add_argument( + "--scope-fingerprint", + default=None, + help="Coordinator-resolved scope fingerprint expected by bootstrap artifacts", + ) parser.add_argument( "--geography-path", default=None, @@ -212,6 +120,34 @@ def _load_request_inputs_from_args( return "work_items", tuple(json.loads(args.work_items)) +def _build_publishing_inputs(*, args, run_id: str): + """Build the traceability input bundle consumed by worker setup services.""" + + from policyengine_us_data.build_outputs.worker_inputs import ( + WorkerCalibrationInputs, + ) + + worker_inputs = WorkerCalibrationInputs( + weights_path=Path(args.weights_path), + dataset_path=Path(args.dataset_path), + database_path=Path(args.db_path), + geography_path=( + Path(args.geography_path) if args.geography_path is not None else None + ), + calibration_package_path=( + Path(args.calibration_package_path) + if args.calibration_package_path is not None + else None + ), + run_config_path=( + Path(args.run_config_path) if args.run_config_path is not None else None + ), + n_clones=args.n_clones, + seed=args.seed, + ) + return worker_inputs.to_publishing_input_bundle(run_id=run_id) + + def _build_kwargs_from_request(request) -> dict[str, Any]: """Translate a typed request into `build_h5(...)` keyword arguments.""" @@ -296,13 +232,29 @@ def _resolve_request_input( return _request_key(request), request +def _log_worker_session_ready(*, scope: str, session, geography) -> None: + """Write worker-session setup details to stderr for Modal diagnostics.""" + + print( + "Worker session ready: " + f"scope={scope}, bootstrap={session.bootstrap_status}, " + f"{geography.n_clones} clones x {geography.n_records} records", + file=sys.stderr, + ) + bootstrap_error = session.caches.get("bootstrap_error") + if bootstrap_error: + print( + f"Worker bootstrap fallback reason: {bootstrap_error}", + file=sys.stderr, + ) + + def main(argv: list[str] | None = None): args = parse_args(argv) - weights_path = Path(args.weights_path) dataset_path = Path(args.dataset_path) - db_path = Path(args.db_path) output_dir = Path(args.output_dir) + run_id = args.run_id or output_dir.name or "local-worker" from policyengine_us_data.utils.takeup import ( SIMPLE_TAKEUP_VARS, @@ -315,100 +267,48 @@ def main(argv: list[str] | None = None): from policyengine_us_data.calibration.publish_local_area import ( build_h5, - load_calibration_geography, ) from policyengine_us_data.build_outputs.area_catalog import USAreaCatalog from policyengine_us_data.build_outputs.requests import AreaBuildRequest - - weights = np.load(weights_path) - - from policyengine_us import Microsimulation - - _sim = Microsimulation(dataset=str(dataset_path)) - n_records = len(_sim.calculate("household_id", map_to="household").values) - del _sim - - geography = load_calibration_geography( - weights_path=weights_path, - n_records=n_records, - n_clones=args.n_clones, - geography_path=( - Path(args.geography_path) if args.geography_path is not None else None - ), - calibration_package_path=( - Path(args.calibration_package_path) - if args.calibration_package_path is not None - else None - ), - ) - print( - f"Loaded geography: " - f"{geography.n_clones} clones x " - f"{geography.n_records} records", - file=sys.stderr, + from policyengine_us_data.build_outputs.validation import ( + AreaValidationService, + ValidationPolicy, ) + from policyengine_us_data.build_outputs.worker_session import WorkerSessionFactory + area_catalog = USAreaCatalog.default() request_input_mode, request_inputs = _load_request_inputs_from_args( args=args, area_build_request_cls=AreaBuildRequest, ) - - # ── Validation setup (once per worker) ── - validation_targets = None - training_mask_full = None - constraints_map = None - if not args.no_validate: - from sqlalchemy import create_engine - from policyengine_us_data.calibration.validate_staging import ( - _query_all_active_targets, - _batch_stratum_constraints, - ) - from policyengine_us_data.calibration.unified_calibration import ( - load_target_config, - _match_rules, - ) - - engine = create_engine(f"sqlite:///{db_path}") - validation_targets = _query_all_active_targets(engine, args.period) - print( - f"Loaded {len(validation_targets)} validation targets", - file=sys.stderr, - ) - - # Apply exclude/include from validation config - if args.validation_config: - val_cfg = load_target_config(args.validation_config) - exc_rules = val_cfg.get("exclude", []) - if exc_rules: - exc_mask = _match_rules(validation_targets, exc_rules) - validation_targets = validation_targets[~exc_mask].reset_index( - drop=True - ) - inc_rules = val_cfg.get("include", []) - if inc_rules: - inc_mask = _match_rules(validation_targets, inc_rules) - validation_targets = validation_targets[inc_mask].reset_index(drop=True) - - # Compute training mask from training config - if args.target_config: - tr_cfg = load_target_config(args.target_config) - tr_inc = tr_cfg.get("include", []) - if tr_inc: - training_mask_full = np.asarray( - _match_rules(validation_targets, tr_inc), - dtype=bool, - ) - else: - training_mask_full = np.ones(len(validation_targets), dtype=bool) - else: - training_mask_full = np.ones(len(validation_targets), dtype=bool) - - # Batch-load constraints - stratum_ids = validation_targets["stratum_id"].unique().tolist() - constraints_map = _batch_stratum_constraints(engine, stratum_ids) + scope = args.scope + inputs = _build_publishing_inputs(args=args, run_id=run_id) + validation_service = AreaValidationService() + + session = WorkerSessionFactory(validation_service=validation_service).create( + inputs=inputs, + scope=scope, + validation_policy=ValidationPolicy(enabled=not args.no_validate), + period=args.period, + target_config_path=Path(args.target_config) if args.target_config else None, + validation_config_path=( + Path(args.validation_config) if args.validation_config else None + ), + artifacts_dir=Path(args.artifacts_dir) if args.artifacts_dir else None, + expected_scope_fingerprint=args.scope_fingerprint, + ) + weights = session.weights.values + n_records = session.weights.n_records + geography = session.geography + validation_context = session.validation_context + _log_worker_session_ready(scope=scope, session=session, geography=geography) + if ( + validation_context is not None + and validation_context.validation_targets is not None + ): print( - f"Validation ready: {len(validation_targets)} targets, " - f"{len(stratum_ids)} strata", + f"Validation ready: {len(validation_context.validation_targets)} targets, " + f"{len(validation_context.constraints_map or {})} strata", file=sys.stderr, ) @@ -477,42 +377,22 @@ def main(argv: list[str] | None = None): file=sys.stderr, ) - # ── Per-item validation ── - if not args.no_validate and validation_targets is not None: + if not args.no_validate and validation_context is not None: try: - v_rows = _validate_h5_subprocess( + validation_result = validation_service.validate_request( + context=validation_context, h5_path=str(path), request=request, - validation_targets=validation_targets, - training_mask_full=training_mask_full, - constraints_map=constraints_map, - db_path=str(db_path), - period=args.period, ) + v_rows = list(validation_result.rows) results["validation_rows"].extend(v_rows) - n_fail = sum( - 1 for r in v_rows if r.get("sanity_check") == "FAIL" - ) - rae_vals = [ - r["rel_abs_error"] - for r in v_rows - if isinstance( - r.get("rel_abs_error"), - (int, float), - ) - and r["rel_abs_error"] != float("inf") - ] - mean_rae = sum(rae_vals) / len(rae_vals) if rae_vals else 0.0 - results["validation_summary"][request_key] = { - "n_targets": len(v_rows), - "n_sanity_fail": n_fail, - "mean_rel_abs_error": round(mean_rae, 4), - } + summary = dict(validation_result.summary) + results["validation_summary"][request_key] = summary print( f" Validated {request_key}: " - f"{len(v_rows)} targets, " - f"{n_fail} sanity fails, " - f"mean RAE={mean_rae:.4f}", + f"{summary['n_targets']} targets, " + f"{summary['n_sanity_fail']} sanity fails, " + f"mean RAE={summary['mean_rel_abs_error']:.4f}", file=sys.stderr, ) except Exception as ve: diff --git a/policyengine_us_data/build_outputs/__init__.py b/policyengine_us_data/build_outputs/__init__.py index b2a67b265..6f2ec1c4c 100644 --- a/policyengine_us_data/build_outputs/__init__.py +++ b/policyengine_us_data/build_outputs/__init__.py @@ -4,5 +4,6 @@ seams rather than speculative placeholders. The current early slices support H5 output request construction, exact calibration geography loading, fingerprinting, clone-weight shape contracts, worker partitioning, source -dataset snapshot contracts, and introduced worker-bootstrap artifacts. +dataset snapshot contracts, worker input normalization, worker-bootstrap +artifacts, and worker-scoped session and validation context setup. """ diff --git a/policyengine_us_data/build_outputs/source_dataset.py b/policyengine_us_data/build_outputs/source_dataset.py index 3f1bd08d2..37d259189 100644 --- a/policyengine_us_data/build_outputs/source_dataset.py +++ b/policyengine_us_data/build_outputs/source_dataset.py @@ -556,3 +556,32 @@ def load(self, dataset_path: Path) -> SourceDatasetSnapshot: path = Path(dataset_path) simulation = Microsimulation(dataset=str(path)) return SourceDatasetSnapshot.from_simulation(path, simulation) + + def load_with_entity_graph( + self, + dataset_path: Path, + entity_graph: EntityGraph, + ) -> SourceDatasetSnapshot: + """Open a source H5 dataset using a prebuilt structural entity graph. + + Args: + dataset_path: Source H5 dataset path. + entity_graph: Persisted structural entity graph for the dataset. + + Returns: + A `SourceDatasetSnapshot` backed by a PolicyEngine microsimulation + and the supplied entity graph. + """ + + from policyengine_us import Microsimulation + + path = Path(dataset_path) + simulation = Microsimulation(dataset=str(path)) + provider = MicrosimulationVariableProvider(simulation) + return SourceDatasetSnapshot( + dataset_path=path, + time_period=int(simulation.default_calculation_period), + entity_graph=entity_graph, + input_variables=provider.input_variables, + variable_provider=provider, + ) diff --git a/policyengine_us_data/build_outputs/validation.py b/policyengine_us_data/build_outputs/validation.py new file mode 100644 index 000000000..534f53045 --- /dev/null +++ b/policyengine_us_data/build_outputs/validation.py @@ -0,0 +1,490 @@ +"""Worker-scoped validation context for local H5 publication.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Mapping + +import numpy as np + +from policyengine_us_data.pipeline_metadata import pipeline_node + +from .fingerprinting import PublishingInputBundle + +__all__ = [ + "AreaValidationService", + "ValidationContext", + "AreaValidationResult", + "ValidationPolicy", +] + + +@pipeline_node( + id="local_h5_validation_policy", + label="ValidationPolicy", + node_type="library", + description="Worker-scoped local H5 validation policy contract.", + source_file="policyengine_us_data/build_outputs/validation.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=["uv run pytest tests/unit/build_outputs/test_validation.py"], +) +@dataclass(frozen=True) +class ValidationPolicy: + """Validation switch for a local H5 worker session.""" + + enabled: bool = True + + +@pipeline_node( + id="local_h5_validation_context", + label="ValidationContext", + node_type="library", + description="Prepared per-worker local H5 validation target context.", + source_file="policyengine_us_data/build_outputs/validation.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=["uv run pytest tests/unit/build_outputs/test_validation.py"], +) +@dataclass(frozen=True) +class ValidationContext: + """Prepared validation data reused across all requests in one worker.""" + + policy: ValidationPolicy + target_db_path: Path | None + period: int + validation_targets: Any = None + training_mask: np.ndarray | None = None + constraints_map: Mapping[int, Any] | None = None + target_config_path: Path | None = None + validation_config_path: Path | None = None + + def __post_init__(self) -> None: + target_db_path = ( + Path(self.target_db_path) if self.target_db_path is not None else None + ) + target_config_path = ( + Path(self.target_config_path) + if self.target_config_path is not None + else None + ) + validation_config_path = ( + Path(self.validation_config_path) + if self.validation_config_path is not None + else None + ) + object.__setattr__(self, "target_db_path", target_db_path) + object.__setattr__(self, "period", int(self.period)) + object.__setattr__(self, "target_config_path", target_config_path) + object.__setattr__(self, "validation_config_path", validation_config_path) + if self.training_mask is not None: + object.__setattr__( + self, + "training_mask", + np.asarray(self.training_mask, dtype=bool), + ) + if self.constraints_map is not None: + object.__setattr__( + self, + "constraints_map", + {int(key): value for key, value in self.constraints_map.items()}, + ) + + +@dataclass(frozen=True) +class AreaValidationResult: + """Validation rows and summary for one built area H5.""" + + rows: tuple[Mapping[str, Any], ...] = () + summary: Mapping[str, Any] | None = None + + def __post_init__(self) -> None: + rows = tuple(self.rows) + summary = ( + dict(self.summary) + if self.summary is not None + else _validation_summary(rows) + ) + object.__setattr__(self, "rows", rows) + object.__setattr__(self, "summary", summary) + + +@pipeline_node( + id="local_h5_area_validation_service", + label="AreaValidationService", + node_type="library", + description="Prepare local H5 validation targets once per worker session.", + source_file="policyengine_us_data/build_outputs/validation.py", + status="current", + stability="moving", + pathways=["local_h5"], + artifacts_in=["policy_data.db", "target_config.yaml", "target_config_full.yaml"], + validation_commands=["uv run pytest tests/unit/build_outputs/test_validation.py"], +) +class AreaValidationService: + """Build validation state for all H5 requests handled by one worker.""" + + def __init__( + self, + *, + engine_factory: Callable[[str], Any] | None = None, + query_targets: Callable[[Any, int], Any] | None = None, + batch_constraints: Callable[[Any, list[int]], Mapping[int, Any]] | None = None, + load_target_config: Callable[[Path | str], Mapping[str, Any]] | None = None, + match_rules: Callable[[Any, list[Mapping[str, Any]]], Any] | None = None, + validate_h5: Callable[..., list[Mapping[str, Any]]] | None = None, + ) -> None: + """Create a validation service with injectable seams for tests.""" + + self._engine_factory = engine_factory + self._query_targets = query_targets + self._batch_constraints = batch_constraints + self._load_target_config = load_target_config + self._match_rules = match_rules + self._validate_h5 = validate_h5 + + def prepare_context( + self, + *, + inputs: PublishingInputBundle, + policy: ValidationPolicy, + period: int, + target_config_path: Path | None = None, + validation_config_path: Path | None = None, + ) -> ValidationContext | None: + """Load validation targets and constraints once for a worker. + + Returns `None` when validation is disabled. When validation is enabled + but no target database path exists, this returns an empty context so + callers can still inspect the policy and configured paths. + """ + + if not policy.enabled: + return None + + if inputs.target_db_path is None: + return ValidationContext( + policy=policy, + target_db_path=None, + period=period, + target_config_path=target_config_path, + validation_config_path=validation_config_path, + ) + + engine = self._create_engine(Path(inputs.target_db_path)) + try: + validation_targets = self._query_all_targets(engine, period) + validation_targets = self._apply_validation_rules( + validation_targets, + validation_config_path, + ) + training_mask = self._training_mask( + validation_targets, + target_config_path, + ) + stratum_ids = [ + int(item) for item in validation_targets["stratum_id"].unique().tolist() + ] + constraints_map = self._load_constraints(engine, stratum_ids) + finally: + dispose = getattr(engine, "dispose", None) + if callable(dispose): + dispose() + + return ValidationContext( + policy=policy, + target_db_path=Path(inputs.target_db_path), + period=period, + validation_targets=validation_targets, + training_mask=training_mask, + constraints_map=constraints_map, + target_config_path=target_config_path, + validation_config_path=validation_config_path, + ) + + def validate_request( + self, + *, + context: ValidationContext | None, + h5_path: Path | str, + request, + ) -> AreaValidationResult: + """Validate one built area H5 against a prepared worker context.""" + + if ( + context is None + or context.validation_targets is None + or context.target_db_path is None + ): + return AreaValidationResult() + + geo_level = request.validation_geo_level + geographic_ids = tuple(str(item) for item in request.validation_geographic_ids) + if geo_level is None: + return AreaValidationResult() + + area_type = { + "state": "states", + "district": "districts", + "city": "cities", + "national": "national", + }.get(request.area_type) + if area_type is None: + return AreaValidationResult() + + targets = context.validation_targets + if request.area_type == "national": + mask = targets["geo_level"] == geo_level + else: + mask = (targets["geo_level"] == geo_level) & ( + targets["geographic_id"].astype(str).isin(geographic_ids) + ) + + area_targets = targets[mask].reset_index(drop=True) + if len(area_targets) == 0: + return AreaValidationResult() + + area_training = self._area_training_mask(context.training_mask, mask) + area_constraints = self._area_constraints( + area_targets, + context.constraints_map or {}, + ) + rows = tuple( + self._run_validation( + h5_path=str(h5_path), + request=request, + area_type=area_type, + area_targets=area_targets, + area_training=area_training, + constraints_map=area_constraints, + db_path=str(context.target_db_path), + period=context.period, + ) + ) + return AreaValidationResult( + rows=rows, + summary=_validation_summary(rows), + ) + + def _create_engine(self, target_db_path: Path): + if self._engine_factory is not None: + return self._engine_factory(f"sqlite:///{target_db_path}") + + from sqlalchemy import create_engine + + return create_engine(f"sqlite:///{target_db_path}") + + def _query_all_targets(self, engine, period: int): + if self._query_targets is not None: + return self._query_targets(engine, int(period)) + + from policyengine_us_data.calibration.validate_staging import ( + _query_all_active_targets, + ) + + return _query_all_active_targets(engine, int(period)) + + def _load_constraints(self, engine, stratum_ids: list[int]): + if self._batch_constraints is not None: + return self._batch_constraints(engine, stratum_ids) + + from policyengine_us_data.calibration.validate_staging import ( + _batch_stratum_constraints, + ) + + return _batch_stratum_constraints(engine, stratum_ids) + + def _config(self, path: Path | None) -> Mapping[str, Any]: + if path is None: + return {} + + if self._load_target_config is not None: + return self._load_target_config(path) + + from policyengine_us_data.calibration.unified_calibration import ( + load_target_config, + ) + + return load_target_config(path) + + def _match(self, targets, rules: list[Mapping[str, Any]]): + if self._match_rules is not None: + return self._match_rules(targets, rules) + + from policyengine_us_data.calibration.unified_calibration import _match_rules + + return _match_rules(targets, rules) + + def _apply_validation_rules(self, validation_targets, config_path: Path | None): + config = self._config(config_path) + exclude_rules = list(config.get("exclude", [])) + if exclude_rules: + exclude_mask = self._match(validation_targets, exclude_rules) + validation_targets = validation_targets[~exclude_mask].reset_index( + drop=True + ) + + include_rules = list(config.get("include", [])) + if include_rules: + include_mask = self._match(validation_targets, include_rules) + validation_targets = validation_targets[include_mask].reset_index(drop=True) + + return validation_targets + + def _training_mask(self, validation_targets, config_path: Path | None): + config = self._config(config_path) + include_rules = list(config.get("include", [])) + if not include_rules: + return np.ones(len(validation_targets), dtype=bool) + return np.asarray(self._match(validation_targets, include_rules), dtype=bool) + + def _area_training_mask(self, training_mask, area_mask): + if training_mask is None: + return np.ones(int(np.sum(area_mask.to_numpy())), dtype=bool) + return training_mask[area_mask.values] + + def _area_constraints( + self, + area_targets, + constraints_map: Mapping[int, Any], + ) -> dict[int, Any]: + area_strata = area_targets["stratum_id"].unique().tolist() + return { + int(stratum): constraints_map.get(int(stratum), []) + for stratum in area_strata + } + + def _run_validation( + self, + *, + h5_path: str, + request, + area_type: str, + area_targets, + area_training, + constraints_map: Mapping[int, Any], + db_path: str, + period: int, + ) -> list[Mapping[str, Any]]: + if self._validate_h5 is not None: + return self._validate_h5( + h5_path=h5_path, + request=request, + area_type=area_type, + area_targets=area_targets, + area_training=area_training, + constraints_map=constraints_map, + db_path=db_path, + period=period, + ) + return _validate_h5_subprocess( + h5_path=h5_path, + request=request, + area_type=area_type, + area_targets=area_targets, + area_training=area_training, + constraints_map=constraints_map, + db_path=db_path, + period=period, + ) + + +def _validation_summary(rows: tuple[Mapping[str, Any], ...]) -> dict[str, Any]: + n_fail = sum(1 for row in rows if row.get("sanity_check") == "FAIL") + rae_values = [ + row["rel_abs_error"] + for row in rows + if isinstance(row.get("rel_abs_error"), (int, float)) + and row["rel_abs_error"] != float("inf") + ] + mean_rae = sum(rae_values) / len(rae_values) if rae_values else 0.0 + return { + "n_targets": len(rows), + "n_sanity_fail": n_fail, + "mean_rel_abs_error": round(mean_rae, 4), + } + + +def _validate_h5_subprocess( + *, + h5_path: str, + request, + area_type: str, + area_targets, + area_training, + constraints_map: Mapping[int, Any], + db_path: str, + period: int, +) -> list[Mapping[str, Any]]: + """Spawn a subprocess to validate one H5 file and reclaim simulation memory.""" + + import multiprocessing as mp + + ctx = mp.get_context("spawn") + with ctx.Pool(1) as pool: + return pool.apply( + _validate_area_in_subprocess, + ( + h5_path, + area_type, + request.area_id, + request.display_name, + area_targets, + area_training, + constraints_map, + db_path, + period, + ), + ) + + +def _validate_area_in_subprocess( + h5_path, + area_type, + area_id, + display_id, + area_targets, + area_training, + constraints_map, + db_path, + period, +): + """Run validation for one area in an isolated subprocess.""" + + import logging + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", + ) + from policyengine_us import Microsimulation + from sqlalchemy import create_engine + from policyengine_us_data.calibration.validate_staging import ( + _build_variable_entity_map, + validate_area, + ) + + engine = create_engine(f"sqlite:///{db_path}") + try: + sim = Microsimulation(dataset=h5_path) + variable_entity_map = _build_variable_entity_map(sim) + return validate_area( + sim=sim, + targets_df=area_targets, + engine=engine, + area_type=area_type, + area_id=area_id, + display_id=display_id, + dataset_path=h5_path, + period=period, + training_mask=area_training, + variable_entity_map=variable_entity_map, + constraints_map=constraints_map, + ) + finally: + dispose = getattr(engine, "dispose", None) + if callable(dispose): + dispose() diff --git a/policyengine_us_data/build_outputs/worker_inputs.py b/policyengine_us_data/build_outputs/worker_inputs.py new file mode 100644 index 000000000..c0896597b --- /dev/null +++ b/policyengine_us_data/build_outputs/worker_inputs.py @@ -0,0 +1,231 @@ +"""Normalized input payloads for local H5 worker execution.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Mapping, TypeAlias + +from policyengine_us_data.pipeline_metadata import pipeline_node + +from .fingerprinting import PublishingInputBundle + +WorkerCalibrationInputValue: TypeAlias = str | int +WorkerCalibrationInputPayload: TypeAlias = dict[ + str, + WorkerCalibrationInputValue, +] + + +def _coerce_path(value: object, *, field_name: str) -> Path: + """Return a path from a wire value or raise a clear contract error.""" + + if isinstance(value, Path): + return value + if isinstance(value, str): + return Path(value) + raise TypeError(f"{field_name} must be a path string, got {type(value).__name__}") + + +def _coerce_optional_path(value: object, *, field_name: str) -> Path | None: + """Return an optional path from a wire value.""" + + if value is None: + return None + return _coerce_path(value, field_name=field_name) + + +def _coerce_int(value: object, *, field_name: str) -> int: + """Return an integer from a wire value or raise a clear contract error.""" + + if isinstance(value, bool): + raise TypeError(f"{field_name} must be an int, got bool") + if isinstance(value, int): + return value + if isinstance(value, str): + return int(value) + raise TypeError(f"{field_name} must be an int, got {type(value).__name__}") + + +@pipeline_node( + id="local_h5_worker_calibration_inputs", + label="WorkerCalibrationInputs", + node_type="library", + description="Normalized worker-execution input payload for local H5 builds.", + source_file="policyengine_us_data/build_outputs/worker_inputs.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_worker_inputs.py" + ], +) +@dataclass(frozen=True) +class WorkerCalibrationInputs: + """Input artifact paths and runtime settings for one H5 worker batch. + + This is the typed library contract. Modal entrypoints may still exchange the + `to_wire_dict()` representation because it is easier to serialize and + inspect, but all producers and consumers should normalize through this + class before using field values. + """ + + weights_path: Path + dataset_path: Path + database_path: Path + geography_path: Path | None = None + calibration_package_path: Path | None = None + run_config_path: Path | None = None + n_clones: int = 430 + seed: int = 42 + + @classmethod + def from_artifact_paths( + cls, + *, + weights_path: Path, + dataset_path: Path, + database_path: Path, + geography_path: Path | None = None, + calibration_package_path: Path | None = None, + run_config_path: Path | None = None, + n_clones: int = 430, + seed: int = 42, + require_optional_paths_exist: bool = True, + ) -> "WorkerCalibrationInputs": + """Build worker inputs from coordinator artifact paths. + + Optional paths are included only when present by default, matching the + previous coordinator behavior for run configs and calibration packages. + """ + + if require_optional_paths_exist: + geography_path = geography_path if _exists(geography_path) else None + calibration_package_path = ( + calibration_package_path if _exists(calibration_package_path) else None + ) + run_config_path = run_config_path if _exists(run_config_path) else None + + return cls( + weights_path=weights_path, + dataset_path=dataset_path, + database_path=database_path, + geography_path=geography_path, + calibration_package_path=calibration_package_path, + run_config_path=run_config_path, + n_clones=n_clones, + seed=seed, + ) + + @classmethod + def from_wire_dict( + cls, + payload: Mapping[str, object] | "WorkerCalibrationInputs", + ) -> "WorkerCalibrationInputs": + """Normalize a Modal-safe worker input payload.""" + + if isinstance(payload, cls): + return payload + + missing = [ + key for key in ("weights", "dataset", "database") if key not in payload + ] + if missing: + raise KeyError( + "Missing required worker calibration input(s): " + ", ".join(missing) + ) + + return cls( + weights_path=_coerce_path(payload["weights"], field_name="weights"), + dataset_path=_coerce_path(payload["dataset"], field_name="dataset"), + database_path=_coerce_path(payload["database"], field_name="database"), + geography_path=_coerce_optional_path( + payload.get("geography"), + field_name="geography", + ), + calibration_package_path=_coerce_optional_path( + payload.get("calibration_package"), + field_name="calibration_package", + ), + run_config_path=_coerce_optional_path( + payload.get("run_config"), + field_name="run_config", + ), + n_clones=_coerce_int(payload.get("n_clones", 430), field_name="n_clones"), + seed=_coerce_int(payload.get("seed", 42), field_name="seed"), + ) + + def to_wire_dict(self) -> WorkerCalibrationInputPayload: + """Return the Modal-safe payload used by remote worker entrypoints.""" + + payload: WorkerCalibrationInputPayload = { + "weights": str(self.weights_path), + "dataset": str(self.dataset_path), + "database": str(self.database_path), + "n_clones": self.n_clones, + "seed": self.seed, + } + if self.geography_path is not None: + payload["geography"] = str(self.geography_path) + if self.calibration_package_path is not None: + payload["calibration_package"] = str(self.calibration_package_path) + if self.run_config_path is not None: + payload["run_config"] = str(self.run_config_path) + return payload + + def to_worker_cli_args(self) -> list[str]: + """Return worker_script CLI arguments for these inputs.""" + + args = [ + "--weights-path", + str(self.weights_path), + "--dataset-path", + str(self.dataset_path), + "--db-path", + str(self.database_path), + "--n-clones", + str(self.n_clones), + "--seed", + str(self.seed), + ] + if self.geography_path is not None: + args.extend(["--geography-path", str(self.geography_path)]) + if self.calibration_package_path is not None: + args.extend( + [ + "--calibration-package-path", + str(self.calibration_package_path), + ] + ) + if self.run_config_path is not None: + args.extend(["--run-config-path", str(self.run_config_path)]) + return args + + def to_publishing_input_bundle( + self, + *, + run_id: str, + version: str = "", + legacy_blocks_path: Path | None = None, + ) -> PublishingInputBundle: + """Return the traceability/fingerprinting input bundle.""" + + return PublishingInputBundle( + weights_path=self.weights_path, + source_dataset_path=self.dataset_path, + target_db_path=self.database_path, + exact_geography_path=self.geography_path, + calibration_package_path=self.calibration_package_path, + run_config_path=self.run_config_path, + run_id=run_id, + version=version, + n_clones=self.n_clones, + seed=self.seed, + legacy_blocks_path=legacy_blocks_path, + ) + + +def _exists(path: Path | None) -> bool: + """Return whether an optional artifact path exists.""" + + return path is not None and path.exists() diff --git a/policyengine_us_data/build_outputs/worker_session.py b/policyengine_us_data/build_outputs/worker_session.py new file mode 100644 index 000000000..855fe62e9 --- /dev/null +++ b/policyengine_us_data/build_outputs/worker_session.py @@ -0,0 +1,311 @@ +"""Worker-scoped local H5 setup contracts.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal + +import numpy as np + +from policyengine_us_data.pipeline_metadata import pipeline_node + +from .bootstrap import ( + BootstrapScope, + WorkerBootstrapBundle, + WorkerBootstrapStore, + load_entity_graph, +) +from .fingerprinting import PublishingInputBundle +from .geography_loader import CalibrationGeographyLoader +from .source_dataset import PolicyEngineDatasetReader, SourceDatasetSnapshot +from .validation import AreaValidationService, ValidationContext, ValidationPolicy +from .weights import CloneWeightMatrix + +BootstrapStatus = Literal["used", "fallback", "unavailable"] + +__all__ = [ + "BootstrapStatus", + "WorkerSession", + "WorkerSessionFactory", +] + + +@pipeline_node( + id="local_h5_worker_session", + label="WorkerSession", + node_type="library", + description="Worker-scoped local H5 setup state reused across queued requests.", + source_file="policyengine_us_data/build_outputs/worker_session.py", + status="current", + stability="moving", + pathways=["local_h5"], + artifacts_in=[ + "calibration_weights.npy", + "source_imputed_stratified_extended_cps.h5", + "geography_assignment.npz", + "policy_data.db", + "worker_bootstrap.json", + ], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_worker_session.py" + ], +) +@dataclass +class WorkerSession: + """Prepared local H5 state for one worker process.""" + + inputs: PublishingInputBundle + scope: BootstrapScope + source: SourceDatasetSnapshot + weights: CloneWeightMatrix + geography: Any + validation_context: ValidationContext | None = None + bootstrap_bundle: WorkerBootstrapBundle | None = None + bootstrap_status: BootstrapStatus = "unavailable" + caches: dict[str, Any] = field(default_factory=dict) + + +@pipeline_node( + id="local_h5_worker_session_factory", + label="WorkerSessionFactory", + node_type="library", + description="Load local H5 source, weights, geography, and validation context once per worker.", + source_file="policyengine_us_data/build_outputs/worker_session.py", + status="current", + stability="moving", + pathways=["local_h5"], + artifacts_in=[ + "calibration_weights.npy", + "source_imputed_stratified_extended_cps.h5", + "geography_assignment.npz", + "policy_data.db", + "bootstrap/{scope}/worker_bootstrap.json", + "bootstrap/{scope}/entity_graph.npz", + ], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_worker_session.py", + "uv run pytest tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py", + ], +) +class WorkerSessionFactory: + """Build worker-scoped setup from raw inputs or persisted bootstrap facts.""" + + def __init__( + self, + *, + dataset_reader: PolicyEngineDatasetReader | None = None, + geography_loader: CalibrationGeographyLoader | None = None, + validation_service: AreaValidationService | None = None, + bootstrap_store: WorkerBootstrapStore | None = None, + ) -> None: + """Create a session factory with injectable seams for tests.""" + + self._dataset_reader = dataset_reader or PolicyEngineDatasetReader() + self._geography_loader = geography_loader or CalibrationGeographyLoader() + self._validation_service = validation_service or AreaValidationService() + self._bootstrap_store = bootstrap_store + + def create( + self, + *, + inputs: PublishingInputBundle, + scope: BootstrapScope, + validation_policy: ValidationPolicy | None = None, + period: int = 2024, + target_config_path: Path | None = None, + validation_config_path: Path | None = None, + artifacts_dir: Path | None = None, + expected_scope_fingerprint: str | None = None, + ) -> WorkerSession: + """Create a worker session for one local H5 scope. + + Bootstrap artifacts are preferred when present. If they are missing, + stale, or unreadable, the factory falls back to raw source loaders so + rollout can remain dual-path until the bootstrap contract is mandatory. + """ + + bootstrap_store = self._bootstrap_store + if bootstrap_store is None and artifacts_dir is not None: + bootstrap_store = WorkerBootstrapStore(artifacts_dir) + + bundle, bootstrap_error = self._load_bootstrap( + bootstrap_store=bootstrap_store, + scope=scope, + ) + if bundle is not None: + bootstrap_error = self._validate_bootstrap_bundle( + bundle=bundle, + inputs=inputs, + scope=scope, + expected_scope_fingerprint=expected_scope_fingerprint, + ) + if bootstrap_error is not None: + bundle = None + source, bootstrap_status, source_error = self._load_source( + inputs=inputs, + bundle=bundle, + ) + if bootstrap_error is not None and bootstrap_status == "unavailable": + bootstrap_status = "fallback" + fallback_error = source_error or bootstrap_error + + weights = self._load_weights(inputs=inputs, source=source) + geography = self._geography_loader.load( + weights_path=inputs.weights_path, + n_records=weights.n_records, + n_clones=weights.n_clones, + geography_path=inputs.exact_geography_path, + blocks_path=inputs.legacy_blocks_path, + calibration_package_path=inputs.calibration_package_path, + ) + + policy = validation_policy or ValidationPolicy() + validation_context = self._validation_service.prepare_context( + inputs=inputs, + policy=policy, + period=period, + target_config_path=target_config_path, + validation_config_path=validation_config_path, + ) + + caches: dict[str, Any] = {} + if fallback_error is not None: + caches["bootstrap_error"] = str(fallback_error) + + return WorkerSession( + inputs=inputs, + scope=scope, + source=source, + weights=weights, + geography=geography, + validation_context=validation_context, + bootstrap_bundle=bundle if bootstrap_status == "used" else None, + bootstrap_status=bootstrap_status, + caches=caches, + ) + + def _load_bootstrap( + self, + *, + bootstrap_store: WorkerBootstrapStore | None, + scope: BootstrapScope, + ) -> tuple[WorkerBootstrapBundle | None, Exception | None]: + if bootstrap_store is None: + return None, None + + manifest_path = getattr(bootstrap_store, "manifest_path", None) + manifest_exists = False + if callable(manifest_path): + manifest_exists = Path(manifest_path(scope)).exists() + if not manifest_exists: + return None, None + + try: + return bootstrap_store.load(scope), None + except FileNotFoundError as exc: + return None, exc if manifest_exists else None + except Exception as exc: + return None, exc + + def _validate_bootstrap_bundle( + self, + *, + bundle: WorkerBootstrapBundle, + inputs: PublishingInputBundle, + scope: BootstrapScope, + expected_scope_fingerprint: str | None, + ) -> Exception | None: + try: + self._raise_for_bootstrap_mismatch( + bundle=bundle, + inputs=inputs, + scope=scope, + expected_scope_fingerprint=expected_scope_fingerprint, + ) + except Exception as exc: + return exc + return None + + def _raise_for_bootstrap_mismatch( + self, + *, + bundle: WorkerBootstrapBundle, + inputs: PublishingInputBundle, + scope: BootstrapScope, + expected_scope_fingerprint: str | None, + ) -> None: + if bundle.run_id != inputs.run_id: + raise ValueError( + f"Bootstrap run_id {bundle.run_id!r} does not match " + f"worker run_id {inputs.run_id!r}" + ) + if bundle.scope != scope: + raise ValueError( + f"Bootstrap scope {bundle.scope!r} does not match " + f"worker scope {scope!r}" + ) + + if not bundle.entity_graph_path.exists(): + raise FileNotFoundError( + f"Bootstrap entity graph not found: {bundle.entity_graph_path}" + ) + + actual_scope_fingerprint = bundle.traceability.get("scope_fingerprint") + if expected_scope_fingerprint is None: + raise ValueError( + "Bootstrap scope fingerprint cannot be validated without an " + "expected scope fingerprint" + ) + + if actual_scope_fingerprint != expected_scope_fingerprint: + raise ValueError( + f"Bootstrap scope fingerprint {actual_scope_fingerprint!r} " + f"does not match expected fingerprint {expected_scope_fingerprint!r}" + ) + + def _load_source( + self, + *, + inputs: PublishingInputBundle, + bundle: WorkerBootstrapBundle | None, + ) -> tuple[SourceDatasetSnapshot, BootstrapStatus, Exception | None]: + if bundle is not None: + try: + entity_graph = load_entity_graph(bundle.entity_graph_path) + load_with_entity_graph = getattr( + self._dataset_reader, + "load_with_entity_graph", + ) + return ( + load_with_entity_graph( + inputs.source_dataset_path, + entity_graph, + ), + "used", + None, + ) + except Exception as exc: + source = self._dataset_reader.load(inputs.source_dataset_path) + return source, "fallback", exc + + source = self._dataset_reader.load(inputs.source_dataset_path) + return source, "unavailable", None + + def _load_weights( + self, + *, + inputs: PublishingInputBundle, + source: SourceDatasetSnapshot, + ) -> CloneWeightMatrix: + weights_array = np.load(inputs.weights_path) + weights = CloneWeightMatrix.from_vector( + weights_array, + n_records=source.n_households, + ) + if inputs.n_clones is not None and weights.n_clones != int(inputs.n_clones): + raise ValueError( + f"Weight vector implies n_clones={weights.n_clones}, " + f"expected {inputs.n_clones}" + ) + return weights diff --git a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py index 7365805d8..5cfb7635e 100644 --- a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py +++ b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py @@ -8,11 +8,13 @@ import numpy as np import pytest +from policyengine_us_data.build_outputs.bootstrap import WorkerBootstrapBuilder from policyengine_us_data.build_outputs.source_dataset import ( DEFAULT_SUBENTITIES, PolicyEngineDatasetReader, ) from policyengine_us_data.build_outputs.weights import CloneWeightMatrix +from policyengine_us_data.build_outputs.worker_inputs import WorkerCalibrationInputs from tests.integration.build_outputs.fixtures import ( build_request, seed_local_h5_artifacts, @@ -38,42 +40,51 @@ def _run_worker( validate: bool = False, target_config: Path | None = None, validation_config: Path | None = None, + run_id: str = "tiny-worker-run", + scope: str = "regional", + scope_fingerprint: str | None = None, + artifacts_dir: Path | None = None, + return_process: bool = False, ) -> dict: _require_worker_dependencies() if not isinstance(requests, (list, tuple)): requests = (requests,) + worker_inputs = WorkerCalibrationInputs( + weights_path=artifacts.weights_path, + dataset_path=artifacts.dataset_path, + database_path=artifacts.db_path, + geography_path=artifacts.geography_path if use_saved_geography else None, + calibration_package_path=( + artifacts.calibration_package_path if use_package_geography else None + ), + run_config_path=artifacts.run_config_path, + n_clones=artifacts.n_clones, + seed=42, + ) cmd = [ sys.executable, "-m", "modal_app.worker_script", "--requests-json", json.dumps([request.to_dict() for request in requests]), - "--weights-path", - str(artifacts.weights_path), - "--dataset-path", - str(artifacts.dataset_path), - "--db-path", - str(artifacts.db_path), + *worker_inputs.to_worker_cli_args(), "--output-dir", str(output_dir), - "--n-clones", - str(artifacts.n_clones), + "--scope", + scope, + "--run-id", + run_id, ] + if artifacts_dir is not None: + cmd.extend(["--artifacts-dir", str(artifacts_dir)]) + if scope_fingerprint is not None: + cmd.extend(["--scope-fingerprint", scope_fingerprint]) if not validate: cmd.append("--no-validate") if target_config is not None: cmd.extend(["--target-config", str(target_config)]) if validation_config is not None: cmd.extend(["--validation-config", str(validation_config)]) - if use_saved_geography: - cmd.extend(["--geography-path", str(artifacts.geography_path)]) - if use_package_geography: - cmd.extend( - [ - "--calibration-package-path", - str(artifacts.calibration_package_path), - ] - ) result = subprocess.run( cmd, @@ -81,6 +92,8 @@ def _run_worker( text=True, check=True, ) + if return_process: + return result return json.loads(result.stdout) @@ -158,6 +171,7 @@ def test_worker_builds_national_h5_from_package_geography(tmp_path): artifacts=artifacts, output_dir=output_dir, use_package_geography=True, + scope="national", ) assert result["failed"] == [] @@ -191,21 +205,65 @@ def test_worker_validation_runs_for_tiny_district_state_and_national_h5s(tmp_pat validate=True, target_config=target_config, validation_config=validation_config, + return_process=True, ) + parsed = json.loads(result.stdout) - assert result["failed"] == [] - assert result["errors"] == [] - assert result["completed"] == ["district:NC-01", "state:NC", "national:US"] - assert len(result["validation_rows"]) == 3 - assert set(result["validation_summary"]) == { + assert result.stderr.count("Worker session ready:") == 1 + assert parsed["failed"] == [] + assert parsed["errors"] == [] + assert parsed["completed"] == ["district:NC-01", "state:NC", "national:US"] + assert len(parsed["validation_rows"]) == 3 + assert set(parsed["validation_summary"]) == { "district:NC-01", "state:NC", "national:US", } - for summary in result["validation_summary"].values(): + for summary in parsed["validation_summary"].values(): assert summary["n_targets"] == 1 assert summary["n_sanity_fail"] == 0 - for row in result["validation_rows"]: + for row in parsed["validation_rows"]: assert row["variable"] == "household_count" assert row["sanity_check"] == "PASS" assert row["in_training"] is True + + +def test_worker_consumes_scope_bootstrap_when_available(tmp_path): + artifacts = seed_local_h5_artifacts(tmp_path / "bootstrap") + request = build_request("district", geography=artifacts.geography) + output_dir = tmp_path / "bootstrap-out" + artifacts_dir = tmp_path / "pipeline-artifacts" / "run-123" + inputs = WorkerCalibrationInputs( + weights_path=artifacts.weights_path, + dataset_path=artifacts.dataset_path, + database_path=artifacts.db_path, + geography_path=artifacts.geography_path, + calibration_package_path=artifacts.calibration_package_path, + run_config_path=artifacts.run_config_path, + n_clones=artifacts.n_clones, + seed=42, + ).to_publishing_input_bundle(run_id="run-123", version="0.0.0") + WorkerBootstrapBuilder().build( + inputs=inputs, + scope="regional", + artifacts_dir=artifacts_dir, + scope_fingerprint="regional-fingerprint", + ) + + result = _run_worker( + requests=request, + artifacts=artifacts, + output_dir=output_dir, + use_saved_geography=True, + use_package_geography=True, + run_id="run-123", + scope_fingerprint="regional-fingerprint", + artifacts_dir=artifacts_dir, + return_process=True, + ) + parsed = json.loads(result.stdout) + + assert "Worker session ready: scope=regional, bootstrap=used" in result.stderr + assert parsed["failed"] == [] + assert parsed["errors"] == [] + assert parsed["completed"] == [f"district:{request.area_id}"] diff --git a/tests/integration/support/tiny_h5.py b/tests/integration/support/tiny_h5.py index f938017c7..9ae3879fe 100644 --- a/tests/integration/support/tiny_h5.py +++ b/tests/integration/support/tiny_h5.py @@ -23,6 +23,9 @@ AreaBuildRequest, AreaFilter, ) +from policyengine_us_data.build_outputs.worker_inputs import ( + WorkerCalibrationInputs, +) from tests.integration.support.pipeline_workspace import TinyPipelineWorkspace from tests.integration.support.tiny_pipeline import TinyPipelineArtifacts @@ -167,20 +170,22 @@ def build_publishing_input_bundle( ) -> PublishingInputBundle: """Build the same traceability input shape used by local H5 publication.""" - return PublishingInputBundle( + worker_inputs = WorkerCalibrationInputs( weights_path=artifacts.weights_path, - source_dataset_path=artifacts.dataset_path, - target_db_path=artifacts.db_path, - exact_geography_path=artifacts.geography_path, + dataset_path=artifacts.dataset_path, + database_path=artifacts.db_path, + geography_path=artifacts.geography_path, calibration_package_path=( artifacts.calibration_package_path if scope == "regional" else None ), run_config_path=artifacts.run_config_path, - run_id=run_id, - version=VERSION, n_clones=artifacts.n_clones, seed=SEED, ) + return worker_inputs.to_publishing_input_bundle( + run_id=run_id, + version=VERSION, + ) def run_local_h5_worker( @@ -193,33 +198,30 @@ def run_local_h5_worker( ) -> dict: """Run the real local H5 worker subprocess for tiny fixture requests.""" + worker_inputs = WorkerCalibrationInputs( + weights_path=artifacts.weights_path, + dataset_path=artifacts.dataset_path, + database_path=artifacts.db_path, + geography_path=artifacts.geography_path if use_saved_geography else None, + calibration_package_path=( + artifacts.calibration_package_path if use_package_geography else None + ), + n_clones=artifacts.n_clones, + seed=SEED, + ) cmd = [ sys.executable, "-m", "modal_app.worker_script", "--requests-json", json.dumps([request.to_dict() for request in requests]), - "--weights-path", - str(artifacts.weights_path), - "--dataset-path", - str(artifacts.dataset_path), - "--db-path", - str(artifacts.db_path), + *worker_inputs.to_worker_cli_args(), "--output-dir", str(output_dir), - "--n-clones", - str(artifacts.n_clones), + "--scope", + "regional", "--no-validate", ] - if use_saved_geography: - cmd.extend(["--geography-path", str(artifacts.geography_path)]) - if use_package_geography: - cmd.extend( - [ - "--calibration-package-path", - str(artifacts.calibration_package_path), - ] - ) result = subprocess.run( cmd, diff --git a/tests/integration/test_tiny_h5_pipeline.py b/tests/integration/test_tiny_h5_pipeline.py index 5fdd81024..372d65553 100644 --- a/tests/integration/test_tiny_h5_pipeline.py +++ b/tests/integration/test_tiny_h5_pipeline.py @@ -83,21 +83,32 @@ def test_saved_geography_h5_pipeline_builds_regional_and_national_outputs(): assert preflight_result["geography_source"] == "saved_geography" assert len(preflight_result["fingerprint"]) == 16 - build_result = build.remote( + regional_result = build.remote( + branch="main", + run_id=run_id, + scope="regional", + work_items=_work_items("district", "state"), + calibration_inputs=preflight_result["calibration_inputs"], + validate=False, + ) + national_result = build.remote( branch="main", run_id=run_id, - work_items=_work_items("district", "state", "national"), + scope="national", + work_items=_work_items("national"), calibration_inputs=preflight_result["calibration_inputs"], validate=False, ) - assert build_result["failed"] == [] - assert build_result["errors"] == [] - assert build_result["completed"] == [ + assert regional_result["failed"] == [] + assert regional_result["errors"] == [] + assert regional_result["completed"] == [ "district:NC-01", "state:NC", - "national:US", ] + assert national_result["failed"] == [] + assert national_result["errors"] == [] + assert national_result["completed"] == ["national:US"] manifest = validate.remote(branch="main", run_id=run_id, version="0.0.0") assert manifest["totals"]["districts"] == 1 @@ -139,6 +150,7 @@ def test_package_fallback_h5_pipeline_builds_district_output(): build_result = build.remote( branch="main", run_id=run_id, + scope="regional", work_items=_work_items("district"), calibration_inputs=preflight_result["calibration_inputs"], validate=False, @@ -176,6 +188,7 @@ def test_missing_geography_h5_pipeline_fails_clearly(): build_result = build.remote( branch="main", run_id=run_id, + scope="regional", work_items=_work_items("district"), calibration_inputs=preflight_result["calibration_inputs"], validate=False, diff --git a/tests/support/modal_local_area.py b/tests/support/modal_local_area.py index 9e32d9e96..0e5f2a426 100644 --- a/tests/support/modal_local_area.py +++ b/tests/support/modal_local_area.py @@ -83,6 +83,9 @@ def decorator(func): fake_fingerprinting = ModuleType( "policyengine_us_data.build_outputs.fingerprinting" ) + fake_worker_inputs = ModuleType( + "policyengine_us_data.build_outputs.worker_inputs" + ) fake_policyengine.__path__ = [] fake_calibration.__path__ = [] fake_build_outputs.__path__ = [] @@ -114,6 +117,99 @@ def build(self, *args, **kwargs): fake_bootstrap.WorkerBootstrapBuilder = _FakeWorkerBootstrapBuilder fake_fingerprinting.PublishingInputBundle = object + class _FakeWorkerCalibrationInputs: + def __init__( + self, + *, + weights_path, + dataset_path, + database_path, + geography_path=None, + calibration_package_path=None, + run_config_path=None, + n_clones=430, + seed=42, + ): + self.weights_path = weights_path + self.dataset_path = dataset_path + self.database_path = database_path + self.geography_path = geography_path + self.calibration_package_path = calibration_package_path + self.run_config_path = run_config_path + self.n_clones = n_clones + self.seed = seed + + @classmethod + def from_artifact_paths(cls, **kwargs): + for key in ( + "geography_path", + "calibration_package_path", + "run_config_path", + ): + path = kwargs.get(key) + if path is not None and not path.exists(): + kwargs[key] = None + return cls(**kwargs) + + @classmethod + def from_wire_dict(cls, payload): + if isinstance(payload, cls): + return payload + return cls( + weights_path=payload["weights"], + dataset_path=payload["dataset"], + database_path=payload["database"], + geography_path=payload.get("geography"), + calibration_package_path=payload.get("calibration_package"), + run_config_path=payload.get("run_config"), + n_clones=payload.get("n_clones", 430), + seed=payload.get("seed", 42), + ) + + def to_worker_cli_args(self): + args = [ + "--weights-path", + str(self.weights_path), + "--dataset-path", + str(self.dataset_path), + "--db-path", + str(self.database_path), + "--n-clones", + str(self.n_clones), + "--seed", + str(self.seed), + ] + if self.geography_path is not None: + args.extend(["--geography-path", str(self.geography_path)]) + if self.calibration_package_path is not None: + args.extend( + [ + "--calibration-package-path", + str(self.calibration_package_path), + ] + ) + if self.run_config_path is not None: + args.extend(["--run-config-path", str(self.run_config_path)]) + return args + + def to_wire_dict(self): + payload = { + "weights": str(self.weights_path), + "dataset": str(self.dataset_path), + "database": str(self.database_path), + "n_clones": self.n_clones, + "seed": self.seed, + } + if self.geography_path is not None: + payload["geography"] = str(self.geography_path) + if self.calibration_package_path is not None: + payload["calibration_package"] = str(self.calibration_package_path) + if self.run_config_path is not None: + payload["run_config"] = str(self.run_config_path) + return payload + + fake_worker_inputs.WorkerCalibrationInputs = _FakeWorkerCalibrationInputs + class _FakeFingerprintingService: def build_traceability(self, *args, **kwargs): return object() @@ -136,6 +232,9 @@ def compute_scope_fingerprint(self, *args, **kwargs): fake_fingerprinting ), "policyengine_us_data.build_outputs.partitioning": (fake_partitioning), + "policyengine_us_data.build_outputs.worker_inputs": ( + fake_worker_inputs + ), } ) diff --git a/tests/unit/build_outputs/test_validation.py b/tests/unit/build_outputs/test_validation.py new file mode 100644 index 000000000..3455ad96e --- /dev/null +++ b/tests/unit/build_outputs/test_validation.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import pandas as pd + +from policyengine_us_data.build_outputs.fingerprinting import PublishingInputBundle +from policyengine_us_data.build_outputs.validation import ( + AreaValidationService, + ValidationContext, + ValidationPolicy, +) + + +def _inputs(tmp_path: Path, *, with_db: bool = True) -> PublishingInputBundle: + return PublishingInputBundle( + weights_path=tmp_path / "weights.npy", + source_dataset_path=tmp_path / "source.h5", + target_db_path=tmp_path / "policy_data.db" if with_db else None, + exact_geography_path=tmp_path / "geography.npz", + calibration_package_path=None, + run_config_path=None, + run_id="run-123", + version="0.0.0", + n_clones=2, + seed=42, + ) + + +def test_validation_service_returns_none_when_disabled(tmp_path): + service = AreaValidationService() + + context = service.prepare_context( + inputs=_inputs(tmp_path), + policy=ValidationPolicy(enabled=False), + period=2024, + ) + + assert context is None + + +def test_validation_service_returns_empty_context_without_db_path(tmp_path): + service = AreaValidationService() + + context = service.prepare_context( + inputs=_inputs(tmp_path, with_db=False), + policy=ValidationPolicy(), + period=2024, + target_config_path=tmp_path / "target_config.yaml", + ) + + assert isinstance(context, ValidationContext) + assert context.target_db_path is None + assert context.validation_targets is None + assert context.training_mask is None + assert context.target_config_path == tmp_path / "target_config.yaml" + + +def test_validation_service_prepares_targets_training_mask_and_constraints(tmp_path): + engine_urls = [] + constraint_calls = [] + disposed = [] + target_config = tmp_path / "target_config.yaml" + validation_config = tmp_path / "validation_config.yaml" + + class FakeEngine: + def dispose(self): + disposed.append(True) + + def engine_factory(url: str): + engine_urls.append(url) + return FakeEngine() + + def query_targets(engine, period: int): + assert period == 2024 + return pd.DataFrame( + { + "variable": ["household_count", "income", "rent"], + "stratum_id": [1, 2, 3], + "geo_level": ["state", "state", "state"], + "geographic_id": ["37", "37", "37"], + } + ) + + def batch_constraints(engine, stratum_ids: list[int]): + constraint_calls.append(tuple(stratum_ids)) + return {stratum_id: [f"constraint-{stratum_id}"] for stratum_id in stratum_ids} + + def load_config(path: Path | str): + if Path(path) == validation_config: + return {"exclude": [{"variable": "rent"}]} + if Path(path) == target_config: + return {"include": [{"variable": "income"}]} + return {} + + def match_rules(targets, rules): + variables = {rule["variable"] for rule in rules} + return targets["variable"].isin(variables).to_numpy() + + service = AreaValidationService( + engine_factory=engine_factory, + query_targets=query_targets, + batch_constraints=batch_constraints, + load_target_config=load_config, + match_rules=match_rules, + ) + + context = service.prepare_context( + inputs=_inputs(tmp_path), + policy=ValidationPolicy(), + period=2024, + target_config_path=target_config, + validation_config_path=validation_config, + ) + + assert engine_urls == [f"sqlite:///{tmp_path / 'policy_data.db'}"] + assert disposed == [True] + assert context.validation_targets["variable"].tolist() == [ + "household_count", + "income", + ] + assert np.array_equal(context.training_mask, np.array([False, True])) + assert context.constraints_map == { + 1: ["constraint-1"], + 2: ["constraint-2"], + } + assert constraint_calls == [(1, 2)] + + +def test_validation_service_validates_one_area_from_prepared_context(): + calls = [] + targets = pd.DataFrame( + { + "variable": ["household_count", "state_income", "national_income"], + "stratum_id": [1, 2, 3], + "geo_level": ["district", "state", "national"], + "geographic_id": ["3701", "37", "US"], + } + ) + context = ValidationContext( + policy=ValidationPolicy(), + target_db_path=Path("/tmp/policy_data.db"), + period=2024, + validation_targets=targets, + training_mask=np.array([True, False, True]), + constraints_map={1: ["constraint-1"], 2: ["constraint-2"], 3: ["constraint-3"]}, + ) + request = SimpleNamespace( + area_type="district", + area_id="NC-01", + display_name="NC-01", + validation_geo_level="district", + validation_geographic_ids=("3701",), + ) + + def validate_h5(**kwargs): + calls.append(kwargs) + return [ + {"sanity_check": "PASS", "rel_abs_error": 0.25}, + {"sanity_check": "FAIL", "rel_abs_error": float("inf")}, + ] + + result = AreaValidationService(validate_h5=validate_h5).validate_request( + context=context, + h5_path=Path("/tmp/NC-01.h5"), + request=request, + ) + + assert len(calls) == 1 + assert calls[0]["h5_path"] == "/tmp/NC-01.h5" + assert calls[0]["area_type"] == "districts" + assert calls[0]["area_targets"]["variable"].tolist() == ["household_count"] + assert np.array_equal(calls[0]["area_training"], np.array([True])) + assert calls[0]["constraints_map"] == {1: ["constraint-1"]} + assert calls[0]["db_path"] == "/tmp/policy_data.db" + assert calls[0]["period"] == 2024 + assert result.summary == { + "n_targets": 2, + "n_sanity_fail": 1, + "mean_rel_abs_error": 0.25, + } + + +def test_validation_service_filters_national_targets_without_geographic_id(): + calls = [] + targets = pd.DataFrame( + { + "variable": ["state_income", "national_income"], + "stratum_id": [2, 3], + "geo_level": ["state", "national"], + "geographic_id": ["37", "US"], + } + ) + context = ValidationContext( + policy=ValidationPolicy(), + target_db_path=Path("/tmp/policy_data.db"), + period=2024, + validation_targets=targets, + training_mask=np.array([False, True]), + constraints_map={2: ["constraint-2"], 3: ["constraint-3"]}, + ) + request = SimpleNamespace( + area_type="national", + area_id="US", + display_name="US", + validation_geo_level="national", + validation_geographic_ids=("ignored",), + ) + + def validate_h5(**kwargs): + calls.append(kwargs) + return [{"sanity_check": "PASS", "rel_abs_error": 0.0}] + + result = AreaValidationService(validate_h5=validate_h5).validate_request( + context=context, + h5_path=Path("/tmp/US.h5"), + request=request, + ) + + assert calls[0]["area_type"] == "national" + assert calls[0]["area_targets"]["variable"].tolist() == ["national_income"] + assert np.array_equal(calls[0]["area_training"], np.array([True])) + assert calls[0]["constraints_map"] == {3: ["constraint-3"]} + assert result.summary["n_targets"] == 1 + + +def test_validation_service_returns_empty_result_for_unmatched_area(): + called = False + context = ValidationContext( + policy=ValidationPolicy(), + target_db_path=Path("/tmp/policy_data.db"), + period=2024, + validation_targets=pd.DataFrame( + { + "variable": ["state_income"], + "stratum_id": [2], + "geo_level": ["state"], + "geographic_id": ["37"], + } + ), + training_mask=np.array([True]), + constraints_map={2: ["constraint-2"]}, + ) + request = SimpleNamespace( + area_type="district", + area_id="NC-01", + display_name="NC-01", + validation_geo_level="district", + validation_geographic_ids=("3701",), + ) + + def validate_h5(**kwargs): + nonlocal called + called = True + return [] + + result = AreaValidationService(validate_h5=validate_h5).validate_request( + context=context, + h5_path=Path("/tmp/NC-01.h5"), + request=request, + ) + + assert called is False + assert result.rows == () + assert result.summary == { + "n_targets": 0, + "n_sanity_fail": 0, + "mean_rel_abs_error": 0.0, + } diff --git a/tests/unit/build_outputs/test_worker_inputs.py b/tests/unit/build_outputs/test_worker_inputs.py new file mode 100644 index 000000000..3bdc6e9fd --- /dev/null +++ b/tests/unit/build_outputs/test_worker_inputs.py @@ -0,0 +1,146 @@ +from pathlib import Path + +import pytest + +from policyengine_us_data.build_outputs.worker_inputs import ( + WorkerCalibrationInputs, +) + + +def test_worker_calibration_inputs_round_trip_wire_payload(): + inputs = WorkerCalibrationInputs( + weights_path=Path("/tmp/calibration_weights.npy"), + dataset_path=Path("/tmp/source.h5"), + database_path=Path("/tmp/policy_data.db"), + geography_path=Path("/tmp/geography_assignment.npz"), + calibration_package_path=Path("/tmp/calibration_package.pkl"), + run_config_path=Path("/tmp/unified_run_config.json"), + n_clones=4, + seed=123, + ) + + payload = inputs.to_wire_dict() + normalized = WorkerCalibrationInputs.from_wire_dict(payload) + + assert normalized == inputs + assert payload == { + "weights": "/tmp/calibration_weights.npy", + "dataset": "/tmp/source.h5", + "database": "/tmp/policy_data.db", + "geography": "/tmp/geography_assignment.npz", + "calibration_package": "/tmp/calibration_package.pkl", + "run_config": "/tmp/unified_run_config.json", + "n_clones": 4, + "seed": 123, + } + + +def test_worker_calibration_inputs_defaults_legacy_modal_payload_values(): + inputs = WorkerCalibrationInputs.from_wire_dict( + { + "weights": "/tmp/calibration_weights.npy", + "dataset": "/tmp/source.h5", + "database": "/tmp/policy_data.db", + } + ) + + assert inputs.n_clones == 430 + assert inputs.seed == 42 + + +def test_worker_calibration_inputs_build_worker_cli_args(): + inputs = WorkerCalibrationInputs( + weights_path=Path("/tmp/calibration_weights.npy"), + dataset_path=Path("/tmp/source.h5"), + database_path=Path("/tmp/policy_data.db"), + geography_path=Path("/tmp/geography_assignment.npz"), + calibration_package_path=Path("/tmp/calibration_package.pkl"), + run_config_path=Path("/tmp/unified_run_config.json"), + n_clones=4, + seed=123, + ) + + assert inputs.to_worker_cli_args() == [ + "--weights-path", + "/tmp/calibration_weights.npy", + "--dataset-path", + "/tmp/source.h5", + "--db-path", + "/tmp/policy_data.db", + "--n-clones", + "4", + "--seed", + "123", + "--geography-path", + "/tmp/geography_assignment.npz", + "--calibration-package-path", + "/tmp/calibration_package.pkl", + "--run-config-path", + "/tmp/unified_run_config.json", + ] + + +def test_worker_calibration_inputs_build_publishing_input_bundle(): + inputs = WorkerCalibrationInputs( + weights_path=Path("/tmp/calibration_weights.npy"), + dataset_path=Path("/tmp/source.h5"), + database_path=Path("/tmp/policy_data.db"), + geography_path=Path("/tmp/geography_assignment.npz"), + calibration_package_path=Path("/tmp/calibration_package.pkl"), + run_config_path=Path("/tmp/unified_run_config.json"), + n_clones=4, + seed=123, + ) + + bundle = inputs.to_publishing_input_bundle( + run_id="run-123", + version="1.2.3", + legacy_blocks_path=Path("/tmp/stacked_blocks.npy"), + ) + + assert bundle.weights_path == inputs.weights_path + assert bundle.source_dataset_path == inputs.dataset_path + assert bundle.target_db_path == inputs.database_path + assert bundle.exact_geography_path == inputs.geography_path + assert bundle.calibration_package_path == inputs.calibration_package_path + assert bundle.run_config_path == inputs.run_config_path + assert bundle.run_id == "run-123" + assert bundle.version == "1.2.3" + assert bundle.n_clones == 4 + assert bundle.seed == 123 + assert bundle.legacy_blocks_path == Path("/tmp/stacked_blocks.npy") + + +def test_worker_calibration_inputs_omit_missing_optional_artifact_paths(tmp_path): + run_config_path = tmp_path / "unified_run_config.json" + run_config_path.write_text("{}") + + inputs = WorkerCalibrationInputs.from_artifact_paths( + weights_path=tmp_path / "calibration_weights.npy", + dataset_path=tmp_path / "source.h5", + database_path=tmp_path / "policy_data.db", + geography_path=tmp_path / "missing_geography_assignment.npz", + calibration_package_path=tmp_path / "missing_calibration_package.pkl", + run_config_path=run_config_path, + n_clones=4, + seed=123, + ) + + payload = inputs.to_wire_dict() + + assert inputs.geography_path is None + assert inputs.calibration_package_path is None + assert inputs.run_config_path == run_config_path + assert "geography" not in payload + assert "calibration_package" not in payload + assert payload["run_config"] == str(run_config_path) + + +def test_worker_calibration_inputs_reject_missing_required_paths(): + with pytest.raises(KeyError, match="weights"): + WorkerCalibrationInputs.from_wire_dict( + { + "dataset": "/tmp/source.h5", + "database": "/tmp/policy_data.db", + } + ) diff --git a/tests/unit/build_outputs/test_worker_session.py b/tests/unit/build_outputs/test_worker_session.py new file mode 100644 index 000000000..7ff142ddb --- /dev/null +++ b/tests/unit/build_outputs/test_worker_session.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import pytest + +from policyengine_us_data.build_outputs.bootstrap import ( + WorkerBootstrapBuilder, + WorkerBootstrapStore, +) +from policyengine_us_data.build_outputs.validation import ( + ValidationContext, + ValidationPolicy, +) +from policyengine_us_data.build_outputs.worker_session import ( + WorkerSession, + WorkerSessionFactory, +) +from tests.support.build_outputs.bootstrap import ( + FakeDatasetReader, + FakeFingerprintingService, + FakeGeographyLoader, + make_bootstrap_test_artifacts, +) + + +class SessionDatasetReader(FakeDatasetReader): + """Dataset reader fake that records raw and bootstrap source loads.""" + + def __init__(self, snapshot, *, fail_with_entity_graph: bool = False): + super().__init__(snapshot) + self.loaded_with_entity_graph: list[tuple[Path, object]] = [] + self.fail_with_entity_graph = fail_with_entity_graph + + def load_with_entity_graph(self, dataset_path: Path, entity_graph): + self.loaded_with_entity_graph.append((Path(dataset_path), entity_graph)) + if self.fail_with_entity_graph: + raise RuntimeError("entity graph load failed") + return self.snapshot + + +class SessionGeographyLoader(FakeGeographyLoader): + """Geography loader fake that records load calls.""" + + def __init__(self, artifacts): + super().__init__(artifacts) + self.load_calls = [] + + def load(self, **kwargs): + self.load_calls.append(kwargs) + return super().load(**kwargs) + + +class FakeValidationService: + """Validation service fake returning a prepared context.""" + + def __init__(self): + self.calls = [] + + def prepare_context(self, **kwargs): + self.calls.append(kwargs) + policy = kwargs["policy"] + if not policy.enabled: + return None + return ValidationContext( + policy=policy, + target_db_path=kwargs["inputs"].target_db_path, + period=kwargs["period"], + validation_targets=SimpleNamespace(name="targets"), + training_mask=np.array([True]), + constraints_map={1: []}, + ) + + +def test_worker_session_caches_are_per_session(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + + first = WorkerSession( + inputs=artifacts.inputs, + scope="regional", + source=artifacts.snapshot, + weights=SimpleNamespace(values=np.ones(2), n_records=1, n_clones=2), + geography=SimpleNamespace(n_records=1, n_clones=2), + ) + second = WorkerSession( + inputs=artifacts.inputs, + scope="regional", + source=artifacts.snapshot, + weights=SimpleNamespace(values=np.ones(2), n_records=1, n_clones=2), + geography=SimpleNamespace(n_records=1, n_clones=2), + ) + + first.caches["marker"] = "first" + + assert second.caches == {} + + +def test_worker_session_factory_uses_raw_loaders_without_bootstrap(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + dataset_reader = SessionDatasetReader(artifacts.snapshot) + geography_loader = SessionGeographyLoader(artifacts) + validation_service = FakeValidationService() + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=geography_loader, + validation_service=validation_service, + ).create( + inputs=artifacts.inputs, + scope="regional", + validation_policy=ValidationPolicy(), + period=2024, + ) + + assert session.bootstrap_status == "unavailable" + assert session.bootstrap_bundle is None + assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] + assert dataset_reader.loaded_with_entity_graph == [] + assert session.weights.n_records == artifacts.n_records + assert session.weights.n_clones == artifacts.n_clones + assert geography_loader.load_calls[0]["n_records"] == artifacts.n_records + assert validation_service.calls[0]["inputs"] == artifacts.inputs + + +def test_worker_session_factory_prefers_bootstrap_entity_graph(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + store = WorkerBootstrapStore(tmp_path / "artifacts") + WorkerBootstrapBuilder( + dataset_reader=FakeDatasetReader(artifacts.snapshot), + geography_loader=FakeGeographyLoader(artifacts), + fingerprinting_service=FakeFingerprintingService(), + ).build( + inputs=artifacts.inputs, + scope="regional", + artifacts_dir=store.artifacts_dir, + ) + dataset_reader = SessionDatasetReader(artifacts.snapshot) + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=SessionGeographyLoader(artifacts), + validation_service=FakeValidationService(), + bootstrap_store=store, + ).create( + inputs=artifacts.inputs, + scope="regional", + validation_policy=ValidationPolicy(), + period=2024, + expected_scope_fingerprint="regional-fingerprint", + ) + + assert session.bootstrap_status == "used" + assert session.bootstrap_bundle is not None + assert dataset_reader.loaded_paths == [] + assert dataset_reader.loaded_with_entity_graph[0][0] == ( + artifacts.inputs.source_dataset_path + ) + + +def test_worker_session_factory_requires_expected_fingerprint_for_bootstrap( + tmp_path, +): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + store = WorkerBootstrapStore(tmp_path / "artifacts") + WorkerBootstrapBuilder( + dataset_reader=FakeDatasetReader(artifacts.snapshot), + geography_loader=FakeGeographyLoader(artifacts), + fingerprinting_service=FakeFingerprintingService(), + ).build( + inputs=artifacts.inputs, + scope="regional", + artifacts_dir=store.artifacts_dir, + ) + dataset_reader = SessionDatasetReader(artifacts.snapshot) + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=SessionGeographyLoader(artifacts), + validation_service=FakeValidationService(), + bootstrap_store=store, + ).create( + inputs=artifacts.inputs, + scope="regional", + validation_policy=ValidationPolicy(), + period=2024, + ) + + assert session.bootstrap_status == "fallback" + assert session.bootstrap_bundle is None + assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] + assert "expected scope fingerprint" in session.caches["bootstrap_error"] + + +def test_worker_session_factory_falls_back_when_bootstrap_source_load_fails(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + store = WorkerBootstrapStore(tmp_path / "artifacts") + WorkerBootstrapBuilder( + dataset_reader=FakeDatasetReader(artifacts.snapshot), + geography_loader=FakeGeographyLoader(artifacts), + fingerprinting_service=FakeFingerprintingService(), + ).build( + inputs=artifacts.inputs, + scope="regional", + artifacts_dir=store.artifacts_dir, + ) + dataset_reader = SessionDatasetReader( + artifacts.snapshot, + fail_with_entity_graph=True, + ) + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=SessionGeographyLoader(artifacts), + validation_service=FakeValidationService(), + bootstrap_store=store, + ).create( + inputs=artifacts.inputs, + scope="regional", + validation_policy=ValidationPolicy(), + period=2024, + expected_scope_fingerprint="regional-fingerprint", + ) + + assert session.bootstrap_status == "fallback" + assert session.bootstrap_bundle is None + assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] + assert "entity graph load failed" in session.caches["bootstrap_error"] + + +def test_worker_session_factory_falls_back_when_bootstrap_inputs_mismatch(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + store = WorkerBootstrapStore(tmp_path / "artifacts") + WorkerBootstrapBuilder( + dataset_reader=FakeDatasetReader(artifacts.snapshot), + geography_loader=FakeGeographyLoader(artifacts), + fingerprinting_service=FakeFingerprintingService(), + ).build( + inputs=artifacts.inputs, + scope="regional", + artifacts_dir=store.artifacts_dir, + ) + changed_inputs = type(artifacts.inputs)( + weights_path=artifacts.inputs.weights_path, + source_dataset_path=artifacts.inputs.source_dataset_path, + target_db_path=artifacts.inputs.target_db_path, + exact_geography_path=artifacts.inputs.exact_geography_path, + calibration_package_path=artifacts.inputs.calibration_package_path, + run_config_path=artifacts.inputs.run_config_path, + run_id="different-run", + version=artifacts.inputs.version, + n_clones=artifacts.inputs.n_clones, + seed=artifacts.inputs.seed, + legacy_blocks_path=artifacts.inputs.legacy_blocks_path, + ) + dataset_reader = SessionDatasetReader(artifacts.snapshot) + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=SessionGeographyLoader(artifacts), + validation_service=FakeValidationService(), + bootstrap_store=store, + ).create( + inputs=changed_inputs, + scope="regional", + validation_policy=ValidationPolicy(), + period=2024, + ) + + assert session.bootstrap_status == "fallback" + assert session.bootstrap_bundle is None + assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] + assert "does not match worker run_id" in session.caches["bootstrap_error"] + + +def test_worker_session_factory_falls_back_when_bootstrap_fingerprint_mismatch( + tmp_path, +): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + store = WorkerBootstrapStore(tmp_path / "artifacts") + WorkerBootstrapBuilder( + dataset_reader=FakeDatasetReader(artifacts.snapshot), + geography_loader=FakeGeographyLoader(artifacts), + fingerprinting_service=FakeFingerprintingService(), + ).build( + inputs=artifacts.inputs, + scope="regional", + artifacts_dir=store.artifacts_dir, + ) + dataset_reader = SessionDatasetReader(artifacts.snapshot) + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=SessionGeographyLoader(artifacts), + validation_service=FakeValidationService(), + bootstrap_store=store, + ).create( + inputs=artifacts.inputs, + scope="regional", + validation_policy=ValidationPolicy(), + period=2024, + expected_scope_fingerprint="changed-fingerprint", + ) + + assert session.bootstrap_status == "fallback" + assert session.bootstrap_bundle is None + assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] + assert "does not match expected fingerprint" in session.caches["bootstrap_error"] + + +def test_worker_session_factory_marks_corrupt_bootstrap_as_fallback(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + store = WorkerBootstrapStore(tmp_path / "artifacts") + store.scope_dir("regional").mkdir(parents=True) + store.manifest_path("regional").write_text("{not-json") + dataset_reader = SessionDatasetReader(artifacts.snapshot) + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=SessionGeographyLoader(artifacts), + validation_service=FakeValidationService(), + bootstrap_store=store, + ).create( + inputs=artifacts.inputs, + scope="regional", + validation_policy=ValidationPolicy(), + period=2024, + ) + + assert session.bootstrap_status == "fallback" + assert session.bootstrap_bundle is None + assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] + assert "Expecting property name" in session.caches["bootstrap_error"] + + +def test_worker_session_factory_rejects_weight_clone_mismatch(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs", n_clones=2) + bad_inputs = type(artifacts.inputs)( + weights_path=artifacts.inputs.weights_path, + source_dataset_path=artifacts.inputs.source_dataset_path, + target_db_path=artifacts.inputs.target_db_path, + exact_geography_path=artifacts.inputs.exact_geography_path, + calibration_package_path=artifacts.inputs.calibration_package_path, + run_config_path=artifacts.inputs.run_config_path, + run_id=artifacts.inputs.run_id, + version=artifacts.inputs.version, + n_clones=3, + seed=artifacts.inputs.seed, + legacy_blocks_path=artifacts.inputs.legacy_blocks_path, + ) + + with pytest.raises(ValueError, match="expected 3"): + WorkerSessionFactory( + dataset_reader=SessionDatasetReader(artifacts.snapshot), + geography_loader=SessionGeographyLoader(artifacts), + validation_service=FakeValidationService(), + ).create( + inputs=bad_inputs, + scope="regional", + validation_policy=ValidationPolicy(enabled=False), + period=2024, + ) diff --git a/tests/unit/test_modal_local_area.py b/tests/unit/test_modal_local_area.py index 08ab03344..da847fd24 100644 --- a/tests/unit/test_modal_local_area.py +++ b/tests/unit/test_modal_local_area.py @@ -328,3 +328,102 @@ def build(self, **kwargs): assert bundle.manifest_path == ( artifacts_dir / "bootstrap" / "regional" / "worker_bootstrap.json" ) + + +def test_build_worker_calibration_inputs_includes_existing_run_config_and_package( + tmp_path, +): + local_area = load_local_area_module() + run_config_path = tmp_path / "unified_run_config.json" + package_path = tmp_path / "calibration_package.pkl" + run_config_path.write_text("{}") + package_path.write_bytes(b"package") + + inputs = local_area._build_worker_calibration_inputs( + weights_path=tmp_path / "calibration_weights.npy", + geography_path=tmp_path / "geography_assignment.npz", + dataset_path=tmp_path / "source.h5", + db_path=tmp_path / "policy_data.db", + n_clones=430, + seed=42, + run_config_path=run_config_path, + calibration_package_path=package_path, + ) + + assert inputs.run_config_path == run_config_path + assert inputs.calibration_package_path == package_path + assert inputs.n_clones == 430 + assert inputs.seed == 42 + assert inputs.to_wire_dict()["run_config"] == str(run_config_path) + + +def test_build_worker_calibration_inputs_omits_missing_optional_files(tmp_path): + local_area = load_local_area_module() + + inputs = local_area._build_worker_calibration_inputs( + weights_path=tmp_path / "national_calibration_weights.npy", + geography_path=tmp_path / "national_geography_assignment.npz", + dataset_path=tmp_path / "source.h5", + db_path=tmp_path / "policy_data.db", + n_clones=430, + seed=42, + run_config_path=tmp_path / "missing_config.json", + calibration_package_path=tmp_path / "missing_package.pkl", + ) + + assert inputs.run_config_path is None + assert inputs.calibration_package_path is None + assert "run_config" not in inputs.to_wire_dict() + assert "calibration_package" not in inputs.to_wire_dict() + + +def test_build_areas_worker_surfaces_successful_worker_stderr( + monkeypatch, + capsys, + tmp_path, +): + local_area = load_local_area_module() + monkeypatch.setattr(local_area, "setup_gcp_credentials", lambda: None) + monkeypatch.setattr(local_area, "setup_repo", lambda branch: None) + monkeypatch.setattr(local_area, "VOLUME_MOUNT", str(tmp_path / "staging")) + monkeypatch.setattr( + local_area, + "pipeline_volume", + SimpleNamespace(reload=lambda: None), + ) + monkeypatch.setattr( + local_area, + "staging_volume", + SimpleNamespace(reload=lambda: None, commit=lambda: None), + ) + captured_cmd = {} + + def fake_run(cmd, **kwargs): + captured_cmd["cmd"] = cmd + return SimpleNamespace( + returncode=0, + stdout='{"completed": ["district:NC-01"], "failed": [], "errors": []}', + stderr="Worker session ready: scope=regional, bootstrap=used\n", + ) + + monkeypatch.setattr(local_area.subprocess, "run", fake_run) + + result = local_area.build_areas_worker( + branch="main", + run_id="run-123", + scope="regional", + work_items=[{"type": "district", "id": "NC-01"}], + calibration_inputs={ + "weights": "/tmp/calibration_weights.npy", + "dataset": "/tmp/source.h5", + "database": "/tmp/policy_data.db", + }, + validate=False, + scope_fingerprint="regional-fingerprint", + ) + + captured = capsys.readouterr() + assert result["completed"] == ["district:NC-01"] + assert "Worker session ready: scope=regional, bootstrap=used" in captured.err + assert "--scope-fingerprint" in captured_cmd["cmd"] + assert "regional-fingerprint" in captured_cmd["cmd"] diff --git a/tests/unit/test_modal_worker_script.py b/tests/unit/test_modal_worker_script.py index d2e90f46b..2ce7d26e7 100644 --- a/tests/unit/test_modal_worker_script.py +++ b/tests/unit/test_modal_worker_script.py @@ -25,6 +25,8 @@ def test_parse_args_accepts_requests_json(): "/tmp/policy_data.db", "--output-dir", "/tmp/out", + "--scope", + "regional", ] ) @@ -45,6 +47,8 @@ def test_parse_args_accepts_calibration_package_path(): "/tmp/policy_data.db", "--output-dir", "/tmp/out", + "--scope", + "regional", "--calibration-package-path", "/tmp/calibration_package.pkl", ] @@ -53,6 +57,40 @@ def test_parse_args_accepts_calibration_package_path(): assert args.calibration_package_path == "/tmp/calibration_package.pkl" +def test_parse_args_accepts_worker_session_paths(): + args = worker_script.parse_args( + [ + "--requests-json", + "[]", + "--weights-path", + "/tmp/weights.npy", + "--dataset-path", + "/tmp/source.h5", + "--db-path", + "/tmp/policy_data.db", + "--output-dir", + "/tmp/out", + "--scope", + "national", + "--run-id", + "run-123", + "--artifacts-dir", + "/tmp/artifacts/run-123", + "--run-config-path", + "/tmp/unified_run_config.json", + "--scope-fingerprint", + "regional-fingerprint", + ] + ) + + assert args.run_id == "run-123" + assert args.scope == "national" + assert args.artifacts_dir == "/tmp/artifacts/run-123" + assert args.run_config_path == "/tmp/unified_run_config.json" + assert args.scope_fingerprint == "regional-fingerprint" + assert not hasattr(args, "version") + + def test_load_request_inputs_from_args_uses_request_payloads_when_present(): args = SimpleNamespace( requests_json=json.dumps([{"area_type": "national", "area_id": "US"}]), @@ -164,3 +202,21 @@ def test_resolve_output_path_rejects_escaped_request_path(tmp_path): assert "must stay within the worker output_dir" in str(exc) else: raise AssertionError("Expected _resolve_output_path to reject traversal") + + +def test_log_worker_session_ready_includes_bootstrap_fallback_reason(capsys): + session = SimpleNamespace( + bootstrap_status="fallback", + caches={"bootstrap_error": "entity graph load failed"}, + ) + geography = SimpleNamespace(n_clones=2, n_records=10) + + worker_script._log_worker_session_ready( + scope="regional", + session=session, + geography=geography, + ) + + captured = capsys.readouterr() + assert "Worker session ready: scope=regional, bootstrap=fallback" in captured.err + assert "Worker bootstrap fallback reason: entity graph load failed" in captured.err