From 2091f50b4afb367c345006520e756519db3f9d06 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 14:58:29 +0200 Subject: [PATCH 01/17] Add local H5 worker session setup --- changelog.d/h5-worker-session.added | 1 + modal_app/local_area.py | 6 + modal_app/worker_script.py | 194 +++++++------ .../build_outputs/__init__.py | 3 +- .../build_outputs/source_dataset.py | 29 ++ .../build_outputs/validation.py | 267 ++++++++++++++++++ .../build_outputs/worker_session.py | 245 ++++++++++++++++ .../test_worker_script_tiny_fixture.py | 70 ++++- tests/unit/build_outputs/test_validation.py | 128 +++++++++ .../unit/build_outputs/test_worker_session.py | 247 ++++++++++++++++ tests/unit/test_modal_worker_script.py | 54 ++++ 11 files changed, 1148 insertions(+), 96 deletions(-) create mode 100644 changelog.d/h5-worker-session.added create mode 100644 policyengine_us_data/build_outputs/validation.py create mode 100644 policyengine_us_data/build_outputs/worker_session.py create mode 100644 tests/unit/build_outputs/test_validation.py create mode 100644 tests/unit/build_outputs/test_worker_session.py diff --git a/changelog.d/h5-worker-session.added b/changelog.d/h5-worker-session.added new file mode 100644 index 000000000..13dde67e6 --- /dev/null +++ b/changelog.d/h5-worker-session.added @@ -0,0 +1 @@ +Add worker-scoped setup and validation context contracts for local H5 builds. diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 2b36b3cd2..92dbb5dc1 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -715,6 +715,10 @@ def build_areas_worker( calibration_inputs["database"], "--output-dir", str(output_dir), + "--run-id", + run_id, + "--artifacts-dir", + str(Path("/pipeline/artifacts") / run_id), ] if "geography" in calibration_inputs: worker_cmd.extend(["--geography-path", calibration_inputs["geography"]]) @@ -729,6 +733,8 @@ def build_areas_worker( worker_cmd.extend(["--n-clones", str(calibration_inputs["n_clones"])]) if "seed" in calibration_inputs: worker_cmd.extend(["--seed", str(calibration_inputs["seed"])]) + if "run_config" in calibration_inputs: + worker_cmd.extend(["--run-config-path", calibration_inputs["run_config"]]) repo_root = Path("/root/policyengine-us-data") cal_dir = repo_root / "policyengine_us_data" / "calibration" worker_cmd.extend( diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index 06f96f7e2..f76f34efc 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -12,8 +12,6 @@ from pathlib import Path from typing import Any -import numpy as np - def _validate_in_subprocess( h5_path, @@ -149,6 +147,26 @@ def parse_args(argv: list[str] | None = None): parser.add_argument("--dataset-path", required=True) parser.add_argument("--db-path", required=True) parser.add_argument("--output-dir", required=True) + parser.add_argument( + "--run-id", + default=None, + help="Pipeline run ID used for traceability and bootstrap lookup", + ) + parser.add_argument( + "--version", + default="0.0.0", + help="Package or release version associated with the worker run", + ) + parser.add_argument( + "--artifacts-dir", + default=None, + help="Optional run-scoped pipeline artifacts directory containing bootstrap artifacts", + ) + parser.add_argument( + "--run-config-path", + default=None, + help="Optional unified run configuration JSON used for traceability", + ) parser.add_argument( "--geography-path", default=None, @@ -212,6 +230,51 @@ def _load_request_inputs_from_args( return "work_items", tuple(json.loads(args.work_items)) +def _infer_worker_scope(request_input_mode: str, request_inputs) -> str: + """Infer which bootstrap scope matches the queued request set.""" + + if request_input_mode == "requests": + all_national = all( + getattr(request, "area_type", None) == "national" + for request in request_inputs + ) + else: + all_national = all( + isinstance(item, dict) and item.get("type") == "national" + for item in request_inputs + ) + return "national" if all_national else "regional" + + +def _build_publishing_inputs(*, args, run_id: str): + """Build the traceability input bundle consumed by worker setup services.""" + + from policyengine_us_data.build_outputs.fingerprinting import ( + PublishingInputBundle, + ) + + return PublishingInputBundle( + weights_path=Path(args.weights_path), + source_dataset_path=Path(args.dataset_path), + target_db_path=Path(args.db_path) if args.db_path else None, + exact_geography_path=( + Path(args.geography_path) if args.geography_path is not None else None + ), + calibration_package_path=( + Path(args.calibration_package_path) + if args.calibration_package_path is not None + else None + ), + run_config_path=( + Path(args.run_config_path) if args.run_config_path is not None else None + ), + run_id=run_id, + version=args.version, + n_clones=args.n_clones, + seed=args.seed, + ) + + def _build_kwargs_from_request(request) -> dict[str, Any]: """Translate a typed request into `build_h5(...)` keyword arguments.""" @@ -299,10 +362,9 @@ def _resolve_request_input( def main(argv: list[str] | None = None): args = parse_args(argv) - weights_path = Path(args.weights_path) dataset_path = Path(args.dataset_path) - db_path = Path(args.db_path) output_dir = Path(args.output_dir) + run_id = args.run_id or output_dir.name or "local-worker" from policyengine_us_data.utils.takeup import ( SIMPLE_TAKEUP_VARS, @@ -315,100 +377,56 @@ def main(argv: list[str] | None = None): from policyengine_us_data.calibration.publish_local_area import ( build_h5, - load_calibration_geography, ) from policyengine_us_data.build_outputs.area_catalog import USAreaCatalog from policyengine_us_data.build_outputs.requests import AreaBuildRequest + from policyengine_us_data.build_outputs.validation import ValidationPolicy + from policyengine_us_data.build_outputs.worker_session import WorkerSessionFactory - weights = np.load(weights_path) - - from policyengine_us import Microsimulation - - _sim = Microsimulation(dataset=str(dataset_path)) - n_records = len(_sim.calculate("household_id", map_to="household").values) - del _sim - - geography = load_calibration_geography( - weights_path=weights_path, - n_records=n_records, - n_clones=args.n_clones, - geography_path=( - Path(args.geography_path) if args.geography_path is not None else None - ), - calibration_package_path=( - Path(args.calibration_package_path) - if args.calibration_package_path is not None - else None - ), - ) - print( - f"Loaded geography: " - f"{geography.n_clones} clones x " - f"{geography.n_records} records", - file=sys.stderr, - ) area_catalog = USAreaCatalog.default() request_input_mode, request_inputs = _load_request_inputs_from_args( args=args, area_build_request_cls=AreaBuildRequest, ) - - # ── Validation setup (once per worker) ── - validation_targets = None - training_mask_full = None - constraints_map = None - if not args.no_validate: - from sqlalchemy import create_engine - from policyengine_us_data.calibration.validate_staging import ( - _query_all_active_targets, - _batch_stratum_constraints, - ) - from policyengine_us_data.calibration.unified_calibration import ( - load_target_config, - _match_rules, - ) - - engine = create_engine(f"sqlite:///{db_path}") - validation_targets = _query_all_active_targets(engine, args.period) - print( - f"Loaded {len(validation_targets)} validation targets", - file=sys.stderr, - ) - - # Apply exclude/include from validation config - if args.validation_config: - val_cfg = load_target_config(args.validation_config) - exc_rules = val_cfg.get("exclude", []) - if exc_rules: - exc_mask = _match_rules(validation_targets, exc_rules) - validation_targets = validation_targets[~exc_mask].reset_index( - drop=True - ) - inc_rules = val_cfg.get("include", []) - if inc_rules: - inc_mask = _match_rules(validation_targets, inc_rules) - validation_targets = validation_targets[inc_mask].reset_index(drop=True) - - # Compute training mask from training config - if args.target_config: - tr_cfg = load_target_config(args.target_config) - tr_inc = tr_cfg.get("include", []) - if tr_inc: - training_mask_full = np.asarray( - _match_rules(validation_targets, tr_inc), - dtype=bool, - ) - else: - training_mask_full = np.ones(len(validation_targets), dtype=bool) - else: - training_mask_full = np.ones(len(validation_targets), dtype=bool) - - # Batch-load constraints - stratum_ids = validation_targets["stratum_id"].unique().tolist() - constraints_map = _batch_stratum_constraints(engine, stratum_ids) + scope = _infer_worker_scope(request_input_mode, request_inputs) + inputs = _build_publishing_inputs(args=args, run_id=run_id) + + session = WorkerSessionFactory().create( + inputs=inputs, + scope=scope, + validation_policy=ValidationPolicy(enabled=not args.no_validate), + period=args.period, + target_config_path=Path(args.target_config) if args.target_config else None, + validation_config_path=( + Path(args.validation_config) if args.validation_config else None + ), + artifacts_dir=Path(args.artifacts_dir) if args.artifacts_dir else None, + ) + weights = session.weights.values + n_records = session.weights.n_records + geography = session.geography + validation_context = session.validation_context + validation_targets = ( + validation_context.validation_targets + if validation_context is not None + else None + ) + training_mask_full = ( + validation_context.training_mask if validation_context is not None else None + ) + constraints_map = ( + validation_context.constraints_map if validation_context is not None else None + ) + print( + "Worker session ready: " + f"scope={scope}, bootstrap={session.bootstrap_status}, " + f"{geography.n_clones} clones x {geography.n_records} records", + file=sys.stderr, + ) + if validation_targets is not None: print( f"Validation ready: {len(validation_targets)} targets, " - f"{len(stratum_ids)} strata", + f"{len(constraints_map or {})} strata", file=sys.stderr, ) @@ -486,7 +504,7 @@ def main(argv: list[str] | None = None): validation_targets=validation_targets, training_mask_full=training_mask_full, constraints_map=constraints_map, - db_path=str(db_path), + db_path=str(inputs.target_db_path), period=args.period, ) results["validation_rows"].extend(v_rows) diff --git a/policyengine_us_data/build_outputs/__init__.py b/policyengine_us_data/build_outputs/__init__.py index b2a67b265..e0eda9a09 100644 --- a/policyengine_us_data/build_outputs/__init__.py +++ b/policyengine_us_data/build_outputs/__init__.py @@ -4,5 +4,6 @@ seams rather than speculative placeholders. The current early slices support H5 output request construction, exact calibration geography loading, fingerprinting, clone-weight shape contracts, worker partitioning, source -dataset snapshot contracts, and introduced worker-bootstrap artifacts. +dataset snapshot contracts, introduced worker-bootstrap artifacts, and +worker-scoped session and validation context setup. """ diff --git a/policyengine_us_data/build_outputs/source_dataset.py b/policyengine_us_data/build_outputs/source_dataset.py index 3f1bd08d2..37d259189 100644 --- a/policyengine_us_data/build_outputs/source_dataset.py +++ b/policyengine_us_data/build_outputs/source_dataset.py @@ -556,3 +556,32 @@ def load(self, dataset_path: Path) -> SourceDatasetSnapshot: path = Path(dataset_path) simulation = Microsimulation(dataset=str(path)) return SourceDatasetSnapshot.from_simulation(path, simulation) + + def load_with_entity_graph( + self, + dataset_path: Path, + entity_graph: EntityGraph, + ) -> SourceDatasetSnapshot: + """Open a source H5 dataset using a prebuilt structural entity graph. + + Args: + dataset_path: Source H5 dataset path. + entity_graph: Persisted structural entity graph for the dataset. + + Returns: + A `SourceDatasetSnapshot` backed by a PolicyEngine microsimulation + and the supplied entity graph. + """ + + from policyengine_us import Microsimulation + + path = Path(dataset_path) + simulation = Microsimulation(dataset=str(path)) + provider = MicrosimulationVariableProvider(simulation) + return SourceDatasetSnapshot( + dataset_path=path, + time_period=int(simulation.default_calculation_period), + entity_graph=entity_graph, + input_variables=provider.input_variables, + variable_provider=provider, + ) diff --git a/policyengine_us_data/build_outputs/validation.py b/policyengine_us_data/build_outputs/validation.py new file mode 100644 index 000000000..95c5b229b --- /dev/null +++ b/policyengine_us_data/build_outputs/validation.py @@ -0,0 +1,267 @@ +"""Worker-scoped validation context for local H5 publication.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Mapping + +import numpy as np + +from policyengine_us_data.pipeline_metadata import pipeline_node + +from .fingerprinting import PublishingInputBundle + +__all__ = [ + "AreaValidationService", + "ValidationContext", + "ValidationPolicy", +] + + +@pipeline_node( + id="local_h5_validation_policy", + label="ValidationPolicy", + node_type="library", + description="Worker-scoped local H5 validation policy contract.", + source_file="policyengine_us_data/build_outputs/validation.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=["uv run pytest tests/unit/build_outputs/test_validation.py"], +) +@dataclass(frozen=True) +class ValidationPolicy: + """Validation switches for a local H5 worker session. + + The current worker uses `enabled`; the other flags make the policy shape + explicit before later migration slices move validation behavior out of the + legacy worker subprocess. + """ + + enabled: bool = True + fail_on_exception: bool = False + fail_on_validation_failure: bool = False + run_sanity_checks: bool = True + run_target_validation: bool = True + run_national_validation: bool = True + + +@pipeline_node( + id="local_h5_validation_context", + label="ValidationContext", + node_type="library", + description="Prepared per-worker local H5 validation target context.", + source_file="policyengine_us_data/build_outputs/validation.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=["uv run pytest tests/unit/build_outputs/test_validation.py"], +) +@dataclass(frozen=True) +class ValidationContext: + """Prepared validation data reused across all requests in one worker.""" + + policy: ValidationPolicy + target_db_path: Path | None + period: int + validation_targets: Any = None + training_mask: np.ndarray | None = None + constraints_map: Mapping[int, Any] | None = None + target_config_path: Path | None = None + validation_config_path: Path | None = None + + def __post_init__(self) -> None: + target_db_path = ( + Path(self.target_db_path) if self.target_db_path is not None else None + ) + target_config_path = ( + Path(self.target_config_path) + if self.target_config_path is not None + else None + ) + validation_config_path = ( + Path(self.validation_config_path) + if self.validation_config_path is not None + else None + ) + object.__setattr__(self, "target_db_path", target_db_path) + object.__setattr__(self, "period", int(self.period)) + object.__setattr__(self, "target_config_path", target_config_path) + object.__setattr__(self, "validation_config_path", validation_config_path) + if self.training_mask is not None: + object.__setattr__( + self, + "training_mask", + np.asarray(self.training_mask, dtype=bool), + ) + if self.constraints_map is not None: + object.__setattr__( + self, + "constraints_map", + {int(key): value for key, value in self.constraints_map.items()}, + ) + + +@pipeline_node( + id="local_h5_area_validation_service", + label="AreaValidationService", + node_type="library", + description="Prepare local H5 validation targets once per worker session.", + source_file="policyengine_us_data/build_outputs/validation.py", + status="current", + stability="moving", + pathways=["local_h5"], + artifacts_in=["policy_data.db", "target_config.yaml", "target_config_full.yaml"], + validation_commands=["uv run pytest tests/unit/build_outputs/test_validation.py"], +) +class AreaValidationService: + """Build validation state for all H5 requests handled by one worker.""" + + def __init__( + self, + *, + engine_factory: Callable[[str], Any] | None = None, + query_targets: Callable[[Any, int], Any] | None = None, + batch_constraints: Callable[[Any, list[int]], Mapping[int, Any]] | None = None, + load_target_config: Callable[[Path | str], Mapping[str, Any]] | None = None, + match_rules: Callable[[Any, list[Mapping[str, Any]]], Any] | None = None, + ) -> None: + """Create a validation service with injectable seams for tests.""" + + self._engine_factory = engine_factory + self._query_targets = query_targets + self._batch_constraints = batch_constraints + self._load_target_config = load_target_config + self._match_rules = match_rules + + def prepare_context( + self, + *, + inputs: PublishingInputBundle, + policy: ValidationPolicy, + period: int, + target_config_path: Path | None = None, + validation_config_path: Path | None = None, + ) -> ValidationContext | None: + """Load validation targets and constraints once for a worker. + + Returns `None` when validation is disabled. When validation is enabled + but no target database path exists, this returns an empty context so + callers can still inspect the policy and configured paths. + """ + + if not policy.enabled: + return None + + if inputs.target_db_path is None: + return ValidationContext( + policy=policy, + target_db_path=None, + period=period, + target_config_path=target_config_path, + validation_config_path=validation_config_path, + ) + + engine = self._create_engine(Path(inputs.target_db_path)) + try: + validation_targets = self._query_all_targets(engine, period) + validation_targets = self._apply_validation_rules( + validation_targets, + validation_config_path, + ) + training_mask = self._training_mask( + validation_targets, + target_config_path, + ) + stratum_ids = [ + int(item) for item in validation_targets["stratum_id"].unique().tolist() + ] + constraints_map = self._load_constraints(engine, stratum_ids) + finally: + dispose = getattr(engine, "dispose", None) + if callable(dispose): + dispose() + + return ValidationContext( + policy=policy, + target_db_path=Path(inputs.target_db_path), + period=period, + validation_targets=validation_targets, + training_mask=training_mask, + constraints_map=constraints_map, + target_config_path=target_config_path, + validation_config_path=validation_config_path, + ) + + def _create_engine(self, target_db_path: Path): + if self._engine_factory is not None: + return self._engine_factory(f"sqlite:///{target_db_path}") + + from sqlalchemy import create_engine + + return create_engine(f"sqlite:///{target_db_path}") + + def _query_all_targets(self, engine, period: int): + if self._query_targets is not None: + return self._query_targets(engine, int(period)) + + from policyengine_us_data.calibration.validate_staging import ( + _query_all_active_targets, + ) + + return _query_all_active_targets(engine, int(period)) + + def _load_constraints(self, engine, stratum_ids: list[int]): + if self._batch_constraints is not None: + return self._batch_constraints(engine, stratum_ids) + + from policyengine_us_data.calibration.validate_staging import ( + _batch_stratum_constraints, + ) + + return _batch_stratum_constraints(engine, stratum_ids) + + def _config(self, path: Path | None) -> Mapping[str, Any]: + if path is None: + return {} + + if self._load_target_config is not None: + return self._load_target_config(path) + + from policyengine_us_data.calibration.unified_calibration import ( + load_target_config, + ) + + return load_target_config(path) + + def _match(self, targets, rules: list[Mapping[str, Any]]): + if self._match_rules is not None: + return self._match_rules(targets, rules) + + from policyengine_us_data.calibration.unified_calibration import _match_rules + + return _match_rules(targets, rules) + + def _apply_validation_rules(self, validation_targets, config_path: Path | None): + config = self._config(config_path) + exclude_rules = list(config.get("exclude", [])) + if exclude_rules: + exclude_mask = self._match(validation_targets, exclude_rules) + validation_targets = validation_targets[~exclude_mask].reset_index( + drop=True + ) + + include_rules = list(config.get("include", [])) + if include_rules: + include_mask = self._match(validation_targets, include_rules) + validation_targets = validation_targets[include_mask].reset_index(drop=True) + + return validation_targets + + def _training_mask(self, validation_targets, config_path: Path | None): + config = self._config(config_path) + include_rules = list(config.get("include", [])) + if not include_rules: + return np.ones(len(validation_targets), dtype=bool) + return np.asarray(self._match(validation_targets, include_rules), dtype=bool) diff --git a/policyengine_us_data/build_outputs/worker_session.py b/policyengine_us_data/build_outputs/worker_session.py new file mode 100644 index 000000000..694f7f200 --- /dev/null +++ b/policyengine_us_data/build_outputs/worker_session.py @@ -0,0 +1,245 @@ +"""Worker-scoped local H5 setup contracts.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal + +import numpy as np + +from policyengine_us_data.pipeline_metadata import pipeline_node + +from .bootstrap import ( + BootstrapScope, + WorkerBootstrapBundle, + WorkerBootstrapStore, + load_entity_graph, +) +from .fingerprinting import PublishingInputBundle +from .geography_loader import CalibrationGeographyLoader +from .source_dataset import PolicyEngineDatasetReader, SourceDatasetSnapshot +from .validation import AreaValidationService, ValidationContext, ValidationPolicy +from .weights import CloneWeightMatrix + +BootstrapStatus = Literal["used", "fallback", "unavailable"] + +__all__ = [ + "BootstrapStatus", + "WorkerSession", + "WorkerSessionFactory", +] + + +@pipeline_node( + id="local_h5_worker_session", + label="WorkerSession", + node_type="library", + description="Worker-scoped local H5 setup state reused across queued requests.", + source_file="policyengine_us_data/build_outputs/worker_session.py", + status="current", + stability="moving", + pathways=["local_h5"], + artifacts_in=[ + "calibration_weights.npy", + "source_imputed_stratified_extended_cps.h5", + "geography_assignment.npz", + "policy_data.db", + "worker_bootstrap.json", + ], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_worker_session.py" + ], +) +@dataclass +class WorkerSession: + """Prepared local H5 state for one worker process.""" + + inputs: PublishingInputBundle + scope: BootstrapScope + source: SourceDatasetSnapshot + weights: CloneWeightMatrix + geography: Any + validation_context: ValidationContext | None = None + bootstrap_bundle: WorkerBootstrapBundle | None = None + bootstrap_status: BootstrapStatus = "unavailable" + caches: dict[str, Any] = field(default_factory=dict) + + +@pipeline_node( + id="local_h5_worker_session_factory", + label="WorkerSessionFactory", + node_type="library", + description="Load local H5 source, weights, geography, and validation context once per worker.", + source_file="policyengine_us_data/build_outputs/worker_session.py", + status="current", + stability="moving", + pathways=["local_h5"], + artifacts_in=[ + "calibration_weights.npy", + "source_imputed_stratified_extended_cps.h5", + "geography_assignment.npz", + "policy_data.db", + "bootstrap/{scope}/worker_bootstrap.json", + "bootstrap/{scope}/entity_graph.npz", + ], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_worker_session.py", + "uv run pytest tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py", + ], +) +class WorkerSessionFactory: + """Build worker-scoped setup from raw inputs or persisted bootstrap facts.""" + + def __init__( + self, + *, + dataset_reader: PolicyEngineDatasetReader | None = None, + geography_loader: CalibrationGeographyLoader | None = None, + validation_service: AreaValidationService | None = None, + bootstrap_store: WorkerBootstrapStore | None = None, + ) -> None: + """Create a session factory with injectable seams for tests.""" + + self._dataset_reader = dataset_reader or PolicyEngineDatasetReader() + self._geography_loader = geography_loader or CalibrationGeographyLoader() + self._validation_service = validation_service or AreaValidationService() + self._bootstrap_store = bootstrap_store + + def create( + self, + *, + inputs: PublishingInputBundle, + scope: BootstrapScope, + validation_policy: ValidationPolicy | None = None, + period: int = 2024, + target_config_path: Path | None = None, + validation_config_path: Path | None = None, + artifacts_dir: Path | None = None, + ) -> WorkerSession: + """Create a worker session for one local H5 scope. + + Bootstrap artifacts are preferred when present. If they are missing, + stale, or unreadable, the factory falls back to raw source loaders so + rollout can remain dual-path until the bootstrap contract is mandatory. + """ + + bootstrap_store = self._bootstrap_store + if bootstrap_store is None and artifacts_dir is not None: + bootstrap_store = WorkerBootstrapStore(artifacts_dir) + + bundle, bootstrap_error = self._load_bootstrap( + bootstrap_store=bootstrap_store, + scope=scope, + ) + source, bootstrap_status, source_error = self._load_source( + inputs=inputs, + bundle=bundle, + ) + if bootstrap_error is not None and bootstrap_status == "unavailable": + bootstrap_status = "fallback" + fallback_error = source_error or bootstrap_error + + weights = self._load_weights(inputs=inputs, source=source) + geography = self._geography_loader.load( + weights_path=inputs.weights_path, + n_records=weights.n_records, + n_clones=weights.n_clones, + geography_path=inputs.exact_geography_path, + blocks_path=inputs.legacy_blocks_path, + calibration_package_path=inputs.calibration_package_path, + ) + + policy = validation_policy or ValidationPolicy() + validation_context = self._validation_service.prepare_context( + inputs=inputs, + policy=policy, + period=period, + target_config_path=target_config_path, + validation_config_path=validation_config_path, + ) + + caches: dict[str, Any] = {} + if fallback_error is not None: + caches["bootstrap_error"] = str(fallback_error) + + return WorkerSession( + inputs=inputs, + scope=scope, + source=source, + weights=weights, + geography=geography, + validation_context=validation_context, + bootstrap_bundle=bundle if bootstrap_status == "used" else None, + bootstrap_status=bootstrap_status, + caches=caches, + ) + + def _load_bootstrap( + self, + *, + bootstrap_store: WorkerBootstrapStore | None, + scope: BootstrapScope, + ) -> tuple[WorkerBootstrapBundle | None, Exception | None]: + if bootstrap_store is None: + return None, None + + manifest_path = getattr(bootstrap_store, "manifest_path", None) + manifest_exists = False + if callable(manifest_path): + manifest_exists = Path(manifest_path(scope)).exists() + if not manifest_exists: + return None, None + + try: + return bootstrap_store.load(scope), None + except FileNotFoundError as exc: + return None, exc if manifest_exists else None + except Exception as exc: + return None, exc + + def _load_source( + self, + *, + inputs: PublishingInputBundle, + bundle: WorkerBootstrapBundle | None, + ) -> tuple[SourceDatasetSnapshot, BootstrapStatus, Exception | None]: + if bundle is not None: + try: + entity_graph = load_entity_graph(bundle.entity_graph_path) + load_with_entity_graph = getattr( + self._dataset_reader, + "load_with_entity_graph", + ) + return ( + load_with_entity_graph( + inputs.source_dataset_path, + entity_graph, + ), + "used", + None, + ) + except Exception as exc: + source = self._dataset_reader.load(inputs.source_dataset_path) + return source, "fallback", exc + + source = self._dataset_reader.load(inputs.source_dataset_path) + return source, "unavailable", None + + def _load_weights( + self, + *, + inputs: PublishingInputBundle, + source: SourceDatasetSnapshot, + ) -> CloneWeightMatrix: + weights_array = np.load(inputs.weights_path) + weights = CloneWeightMatrix.from_vector( + weights_array, + n_records=source.n_households, + ) + if inputs.n_clones is not None and weights.n_clones != int(inputs.n_clones): + raise ValueError( + f"Weight vector implies n_clones={weights.n_clones}, " + f"expected {inputs.n_clones}" + ) + return weights diff --git a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py index 7365805d8..545aad94f 100644 --- a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py +++ b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py @@ -8,6 +8,8 @@ import numpy as np import pytest +from policyengine_us_data.build_outputs.bootstrap import WorkerBootstrapBuilder +from policyengine_us_data.build_outputs.fingerprinting import PublishingInputBundle from policyengine_us_data.build_outputs.source_dataset import ( DEFAULT_SUBENTITIES, PolicyEngineDatasetReader, @@ -38,6 +40,9 @@ def _run_worker( validate: bool = False, target_config: Path | None = None, validation_config: Path | None = None, + run_id: str = "tiny-worker-run", + artifacts_dir: Path | None = None, + return_process: bool = False, ) -> dict: _require_worker_dependencies() if not isinstance(requests, (list, tuple)): @@ -56,9 +61,15 @@ def _run_worker( str(artifacts.db_path), "--output-dir", str(output_dir), + "--run-id", + run_id, + "--run-config-path", + str(artifacts.run_config_path), "--n-clones", str(artifacts.n_clones), ] + if artifacts_dir is not None: + cmd.extend(["--artifacts-dir", str(artifacts_dir)]) if not validate: cmd.append("--no-validate") if target_config is not None: @@ -81,6 +92,8 @@ def _run_worker( text=True, check=True, ) + if return_process: + return result return json.loads(result.stdout) @@ -191,21 +204,64 @@ def test_worker_validation_runs_for_tiny_district_state_and_national_h5s(tmp_pat validate=True, target_config=target_config, validation_config=validation_config, + return_process=True, ) + parsed = json.loads(result.stdout) - assert result["failed"] == [] - assert result["errors"] == [] - assert result["completed"] == ["district:NC-01", "state:NC", "national:US"] - assert len(result["validation_rows"]) == 3 - assert set(result["validation_summary"]) == { + assert result.stderr.count("Worker session ready:") == 1 + assert parsed["failed"] == [] + assert parsed["errors"] == [] + assert parsed["completed"] == ["district:NC-01", "state:NC", "national:US"] + assert len(parsed["validation_rows"]) == 3 + assert set(parsed["validation_summary"]) == { "district:NC-01", "state:NC", "national:US", } - for summary in result["validation_summary"].values(): + for summary in parsed["validation_summary"].values(): assert summary["n_targets"] == 1 assert summary["n_sanity_fail"] == 0 - for row in result["validation_rows"]: + for row in parsed["validation_rows"]: assert row["variable"] == "household_count" assert row["sanity_check"] == "PASS" assert row["in_training"] is True + + +def test_worker_consumes_scope_bootstrap_when_available(tmp_path): + artifacts = seed_local_h5_artifacts(tmp_path / "bootstrap") + request = build_request("district", geography=artifacts.geography) + output_dir = tmp_path / "bootstrap-out" + artifacts_dir = tmp_path / "pipeline-artifacts" / "run-123" + inputs = PublishingInputBundle( + weights_path=artifacts.weights_path, + source_dataset_path=artifacts.dataset_path, + target_db_path=artifacts.db_path, + exact_geography_path=artifacts.geography_path, + calibration_package_path=artifacts.calibration_package_path, + run_config_path=artifacts.run_config_path, + run_id="run-123", + version="0.0.0", + n_clones=artifacts.n_clones, + seed=42, + ) + WorkerBootstrapBuilder().build( + inputs=inputs, + scope="regional", + artifacts_dir=artifacts_dir, + ) + + result = _run_worker( + requests=request, + artifacts=artifacts, + output_dir=output_dir, + use_saved_geography=True, + run_id="run-123", + artifacts_dir=artifacts_dir, + return_process=True, + ) + parsed = json.loads(result.stdout) + + assert "Worker session ready: scope=regional, bootstrap=used" in result.stderr + assert parsed["failed"] == [] + assert parsed["errors"] == [] + assert parsed["completed"] == [f"district:{request.area_id}"] diff --git a/tests/unit/build_outputs/test_validation.py b/tests/unit/build_outputs/test_validation.py new file mode 100644 index 000000000..f5e8c855e --- /dev/null +++ b/tests/unit/build_outputs/test_validation.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pandas as pd + +from policyengine_us_data.build_outputs.fingerprinting import PublishingInputBundle +from policyengine_us_data.build_outputs.validation import ( + AreaValidationService, + ValidationContext, + ValidationPolicy, +) + + +def _inputs(tmp_path: Path, *, with_db: bool = True) -> PublishingInputBundle: + return PublishingInputBundle( + weights_path=tmp_path / "weights.npy", + source_dataset_path=tmp_path / "source.h5", + target_db_path=tmp_path / "policy_data.db" if with_db else None, + exact_geography_path=tmp_path / "geography.npz", + calibration_package_path=None, + run_config_path=None, + run_id="run-123", + version="0.0.0", + n_clones=2, + seed=42, + ) + + +def test_validation_service_returns_none_when_disabled(tmp_path): + service = AreaValidationService() + + context = service.prepare_context( + inputs=_inputs(tmp_path), + policy=ValidationPolicy(enabled=False), + period=2024, + ) + + assert context is None + + +def test_validation_service_returns_empty_context_without_db_path(tmp_path): + service = AreaValidationService() + + context = service.prepare_context( + inputs=_inputs(tmp_path, with_db=False), + policy=ValidationPolicy(), + period=2024, + target_config_path=tmp_path / "target_config.yaml", + ) + + assert isinstance(context, ValidationContext) + assert context.target_db_path is None + assert context.validation_targets is None + assert context.training_mask is None + assert context.target_config_path == tmp_path / "target_config.yaml" + + +def test_validation_service_prepares_targets_training_mask_and_constraints(tmp_path): + engine_urls = [] + constraint_calls = [] + disposed = [] + target_config = tmp_path / "target_config.yaml" + validation_config = tmp_path / "validation_config.yaml" + + class FakeEngine: + def dispose(self): + disposed.append(True) + + def engine_factory(url: str): + engine_urls.append(url) + return FakeEngine() + + def query_targets(engine, period: int): + assert period == 2024 + return pd.DataFrame( + { + "variable": ["household_count", "income", "rent"], + "stratum_id": [1, 2, 3], + "geo_level": ["state", "state", "state"], + "geographic_id": ["37", "37", "37"], + } + ) + + def batch_constraints(engine, stratum_ids: list[int]): + constraint_calls.append(tuple(stratum_ids)) + return {stratum_id: [f"constraint-{stratum_id}"] for stratum_id in stratum_ids} + + def load_config(path: Path | str): + if Path(path) == validation_config: + return {"exclude": [{"variable": "rent"}]} + if Path(path) == target_config: + return {"include": [{"variable": "income"}]} + return {} + + def match_rules(targets, rules): + variables = {rule["variable"] for rule in rules} + return targets["variable"].isin(variables).to_numpy() + + service = AreaValidationService( + engine_factory=engine_factory, + query_targets=query_targets, + batch_constraints=batch_constraints, + load_target_config=load_config, + match_rules=match_rules, + ) + + context = service.prepare_context( + inputs=_inputs(tmp_path), + policy=ValidationPolicy(), + period=2024, + target_config_path=target_config, + validation_config_path=validation_config, + ) + + assert engine_urls == [f"sqlite:///{tmp_path / 'policy_data.db'}"] + assert disposed == [True] + assert context.validation_targets["variable"].tolist() == [ + "household_count", + "income", + ] + assert np.array_equal(context.training_mask, np.array([False, True])) + assert context.constraints_map == { + 1: ["constraint-1"], + 2: ["constraint-2"], + } + assert constraint_calls == [(1, 2)] diff --git a/tests/unit/build_outputs/test_worker_session.py b/tests/unit/build_outputs/test_worker_session.py new file mode 100644 index 000000000..06902060b --- /dev/null +++ b/tests/unit/build_outputs/test_worker_session.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import pytest + +from policyengine_us_data.build_outputs.bootstrap import ( + WorkerBootstrapBuilder, + WorkerBootstrapStore, +) +from policyengine_us_data.build_outputs.validation import ( + ValidationContext, + ValidationPolicy, +) +from policyengine_us_data.build_outputs.worker_session import ( + WorkerSession, + WorkerSessionFactory, +) +from tests.support.build_outputs.bootstrap import ( + FakeDatasetReader, + FakeFingerprintingService, + FakeGeographyLoader, + make_bootstrap_test_artifacts, +) + + +class SessionDatasetReader(FakeDatasetReader): + """Dataset reader fake that records raw and bootstrap source loads.""" + + def __init__(self, snapshot, *, fail_with_entity_graph: bool = False): + super().__init__(snapshot) + self.loaded_with_entity_graph: list[tuple[Path, object]] = [] + self.fail_with_entity_graph = fail_with_entity_graph + + def load_with_entity_graph(self, dataset_path: Path, entity_graph): + self.loaded_with_entity_graph.append((Path(dataset_path), entity_graph)) + if self.fail_with_entity_graph: + raise RuntimeError("entity graph load failed") + return self.snapshot + + +class SessionGeographyLoader(FakeGeographyLoader): + """Geography loader fake that records load calls.""" + + def __init__(self, artifacts): + super().__init__(artifacts) + self.load_calls = [] + + def load(self, **kwargs): + self.load_calls.append(kwargs) + return super().load(**kwargs) + + +class FakeValidationService: + """Validation service fake returning a prepared context.""" + + def __init__(self): + self.calls = [] + + def prepare_context(self, **kwargs): + self.calls.append(kwargs) + policy = kwargs["policy"] + if not policy.enabled: + return None + return ValidationContext( + policy=policy, + target_db_path=kwargs["inputs"].target_db_path, + period=kwargs["period"], + validation_targets=SimpleNamespace(name="targets"), + training_mask=np.array([True]), + constraints_map={1: []}, + ) + + +def test_worker_session_caches_are_per_session(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + + first = WorkerSession( + inputs=artifacts.inputs, + scope="regional", + source=artifacts.snapshot, + weights=SimpleNamespace(values=np.ones(2), n_records=1, n_clones=2), + geography=SimpleNamespace(n_records=1, n_clones=2), + ) + second = WorkerSession( + inputs=artifacts.inputs, + scope="regional", + source=artifacts.snapshot, + weights=SimpleNamespace(values=np.ones(2), n_records=1, n_clones=2), + geography=SimpleNamespace(n_records=1, n_clones=2), + ) + + first.caches["marker"] = "first" + + assert second.caches == {} + + +def test_worker_session_factory_uses_raw_loaders_without_bootstrap(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + dataset_reader = SessionDatasetReader(artifacts.snapshot) + geography_loader = SessionGeographyLoader(artifacts) + validation_service = FakeValidationService() + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=geography_loader, + validation_service=validation_service, + ).create( + inputs=artifacts.inputs, + scope="regional", + validation_policy=ValidationPolicy(), + period=2024, + ) + + assert session.bootstrap_status == "unavailable" + assert session.bootstrap_bundle is None + assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] + assert dataset_reader.loaded_with_entity_graph == [] + assert session.weights.n_records == artifacts.n_records + assert session.weights.n_clones == artifacts.n_clones + assert geography_loader.load_calls[0]["n_records"] == artifacts.n_records + assert validation_service.calls[0]["inputs"] == artifacts.inputs + + +def test_worker_session_factory_prefers_bootstrap_entity_graph(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + store = WorkerBootstrapStore(tmp_path / "artifacts") + WorkerBootstrapBuilder( + dataset_reader=FakeDatasetReader(artifacts.snapshot), + geography_loader=FakeGeographyLoader(artifacts), + fingerprinting_service=FakeFingerprintingService(), + ).build( + inputs=artifacts.inputs, + scope="regional", + artifacts_dir=store.artifacts_dir, + ) + dataset_reader = SessionDatasetReader(artifacts.snapshot) + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=SessionGeographyLoader(artifacts), + validation_service=FakeValidationService(), + bootstrap_store=store, + ).create( + inputs=artifacts.inputs, + scope="regional", + validation_policy=ValidationPolicy(), + period=2024, + ) + + assert session.bootstrap_status == "used" + assert session.bootstrap_bundle is not None + assert dataset_reader.loaded_paths == [] + assert dataset_reader.loaded_with_entity_graph[0][0] == ( + artifacts.inputs.source_dataset_path + ) + + +def test_worker_session_factory_falls_back_when_bootstrap_source_load_fails(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + store = WorkerBootstrapStore(tmp_path / "artifacts") + WorkerBootstrapBuilder( + dataset_reader=FakeDatasetReader(artifacts.snapshot), + geography_loader=FakeGeographyLoader(artifacts), + fingerprinting_service=FakeFingerprintingService(), + ).build( + inputs=artifacts.inputs, + scope="regional", + artifacts_dir=store.artifacts_dir, + ) + dataset_reader = SessionDatasetReader( + artifacts.snapshot, + fail_with_entity_graph=True, + ) + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=SessionGeographyLoader(artifacts), + validation_service=FakeValidationService(), + bootstrap_store=store, + ).create( + inputs=artifacts.inputs, + scope="regional", + validation_policy=ValidationPolicy(), + period=2024, + ) + + assert session.bootstrap_status == "fallback" + assert session.bootstrap_bundle is None + assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] + assert "entity graph load failed" in session.caches["bootstrap_error"] + + +def test_worker_session_factory_marks_corrupt_bootstrap_as_fallback(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + store = WorkerBootstrapStore(tmp_path / "artifacts") + store.scope_dir("regional").mkdir(parents=True) + store.manifest_path("regional").write_text("{not-json") + dataset_reader = SessionDatasetReader(artifacts.snapshot) + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=SessionGeographyLoader(artifacts), + validation_service=FakeValidationService(), + bootstrap_store=store, + ).create( + inputs=artifacts.inputs, + scope="regional", + validation_policy=ValidationPolicy(), + period=2024, + ) + + assert session.bootstrap_status == "fallback" + assert session.bootstrap_bundle is None + assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] + assert "Expecting property name" in session.caches["bootstrap_error"] + + +def test_worker_session_factory_rejects_weight_clone_mismatch(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs", n_clones=2) + bad_inputs = type(artifacts.inputs)( + weights_path=artifacts.inputs.weights_path, + source_dataset_path=artifacts.inputs.source_dataset_path, + target_db_path=artifacts.inputs.target_db_path, + exact_geography_path=artifacts.inputs.exact_geography_path, + calibration_package_path=artifacts.inputs.calibration_package_path, + run_config_path=artifacts.inputs.run_config_path, + run_id=artifacts.inputs.run_id, + version=artifacts.inputs.version, + n_clones=3, + seed=artifacts.inputs.seed, + legacy_blocks_path=artifacts.inputs.legacy_blocks_path, + ) + + with pytest.raises(ValueError, match="expected 3"): + WorkerSessionFactory( + dataset_reader=SessionDatasetReader(artifacts.snapshot), + geography_loader=SessionGeographyLoader(artifacts), + validation_service=FakeValidationService(), + ).create( + inputs=bad_inputs, + scope="regional", + validation_policy=ValidationPolicy(enabled=False), + period=2024, + ) diff --git a/tests/unit/test_modal_worker_script.py b/tests/unit/test_modal_worker_script.py index d2e90f46b..3ddcdce99 100644 --- a/tests/unit/test_modal_worker_script.py +++ b/tests/unit/test_modal_worker_script.py @@ -53,6 +53,33 @@ def test_parse_args_accepts_calibration_package_path(): assert args.calibration_package_path == "/tmp/calibration_package.pkl" +def test_parse_args_accepts_worker_session_paths(): + args = worker_script.parse_args( + [ + "--requests-json", + "[]", + "--weights-path", + "/tmp/weights.npy", + "--dataset-path", + "/tmp/source.h5", + "--db-path", + "/tmp/policy_data.db", + "--output-dir", + "/tmp/out", + "--run-id", + "run-123", + "--artifacts-dir", + "/tmp/artifacts/run-123", + "--run-config-path", + "/tmp/unified_run_config.json", + ] + ) + + assert args.run_id == "run-123" + assert args.artifacts_dir == "/tmp/artifacts/run-123" + assert args.run_config_path == "/tmp/unified_run_config.json" + + def test_load_request_inputs_from_args_uses_request_payloads_when_present(): args = SimpleNamespace( requests_json=json.dumps([{"area_type": "national", "area_id": "US"}]), @@ -84,6 +111,33 @@ def test_load_request_inputs_from_args_keeps_legacy_work_items_raw(): assert work_items == ({"type": "national", "id": "US"},) +def test_infer_worker_scope_uses_national_only_for_national_bootstrap(): + assert ( + worker_script._infer_worker_scope( + "requests", + (FakeRequest(area_type="national", area_id="US"),), + ) + == "national" + ) + assert ( + worker_script._infer_worker_scope( + "requests", + ( + FakeRequest(area_type="district", area_id="NC-01"), + FakeRequest(area_type="national", area_id="US"), + ), + ) + == "regional" + ) + assert ( + worker_script._infer_worker_scope( + "work_items", + ({"type": "national", "id": "US"},), + ) + == "national" + ) + + def test_work_item_key_handles_missing_fields(): assert worker_script._work_item_key({"type": "district"}) == "district:" assert worker_script._work_item_key(["not-a-dict"]) == "unknown:" From 36dce234880e49d8ca5518cb403e7bb74d68c6c3 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 16:01:21 +0200 Subject: [PATCH 02/17] Pass run config to local H5 workers --- modal_app/local_area.py | 79 ++++++++++++++++++++--------- tests/unit/test_modal_local_area.py | 44 ++++++++++++++++ 2 files changed, 98 insertions(+), 25 deletions(-) diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 92dbb5dc1..c4cc88047 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -489,6 +489,34 @@ def _build_worker_bootstrap( return bundle +def _build_worker_calibration_inputs( + *, + weights_path: Path, + geography_path: Path, + dataset_path: Path, + db_path: Path, + n_clones: int, + seed: int, + run_config_path: Path | None = None, + calibration_package_path: Path | None = None, +) -> Dict[str, object]: + """Build the calibration input payload passed to H5 worker subprocesses.""" + + calibration_inputs: Dict[str, object] = { + "weights": str(weights_path), + "geography": str(geography_path), + "dataset": str(dataset_path), + "database": str(db_path), + "n_clones": n_clones, + "seed": seed, + } + if run_config_path is not None and run_config_path.exists(): + calibration_inputs["run_config"] = str(run_config_path) + if calibration_package_path is not None and calibration_package_path.exists(): + calibration_inputs["calibration_package"] = str(calibration_package_path) + return calibration_inputs + + @pipeline_node( PipelineNode( id="coordinate_work_partition", @@ -686,7 +714,7 @@ def build_areas_worker( branch: str, run_id: str, work_items: List[Dict], - calibration_inputs: Dict[str, str], + calibration_inputs: Dict[str, object], validate: bool = True, ) -> Dict: """ @@ -708,11 +736,11 @@ def build_areas_worker( "--work-items", work_items_json, "--weights-path", - calibration_inputs["weights"], + str(calibration_inputs["weights"]), "--dataset-path", - calibration_inputs["dataset"], + str(calibration_inputs["dataset"]), "--db-path", - calibration_inputs["database"], + str(calibration_inputs["database"]), "--output-dir", str(output_dir), "--run-id", @@ -721,12 +749,12 @@ def build_areas_worker( str(Path("/pipeline/artifacts") / run_id), ] if "geography" in calibration_inputs: - worker_cmd.extend(["--geography-path", calibration_inputs["geography"]]) + worker_cmd.extend(["--geography-path", str(calibration_inputs["geography"])]) if "calibration_package" in calibration_inputs: worker_cmd.extend( [ "--calibration-package-path", - calibration_inputs["calibration_package"], + str(calibration_inputs["calibration_package"]), ] ) if "n_clones" in calibration_inputs: @@ -734,7 +762,7 @@ def build_areas_worker( if "seed" in calibration_inputs: worker_cmd.extend(["--seed", str(calibration_inputs["seed"])]) if "run_config" in calibration_inputs: - worker_cmd.extend(["--run-config-path", calibration_inputs["run_config"]]) + worker_cmd.extend(["--run-config-path", str(calibration_inputs["run_config"])]) repo_root = Path("/root/policyengine-us-data") cal_dir = repo_root / "policyengine_us_data" / "calibration" worker_cmd.extend( @@ -1091,16 +1119,16 @@ def coordinate_publish( ) print("All required pipeline artifacts found on volume.") - calibration_inputs = { - "weights": str(weights_path), - "geography": str(geography_path), - "dataset": str(dataset_path), - "database": str(db_path), - "n_clones": n_clones, - "seed": 42, - } - if calibration_package_path.exists(): - calibration_inputs["calibration_package"] = str(calibration_package_path) + calibration_inputs = _build_worker_calibration_inputs( + weights_path=weights_path, + geography_path=geography_path, + dataset_path=dataset_path, + db_path=db_path, + n_clones=n_clones, + seed=42, + run_config_path=config_json_path, + calibration_package_path=calibration_package_path, + ) validate_artifacts(config_json_path, artifacts) if validate: @@ -1402,14 +1430,15 @@ def coordinate_national_publish( ) print("All required national pipeline artifacts found.") - calibration_inputs = { - "weights": str(weights_path), - "geography": str(geography_path), - "dataset": str(dataset_path), - "database": str(db_path), - "n_clones": n_clones, - "seed": 42, - } + calibration_inputs = _build_worker_calibration_inputs( + weights_path=weights_path, + geography_path=geography_path, + dataset_path=dataset_path, + db_path=db_path, + n_clones=n_clones, + seed=42, + run_config_path=config_json_path, + ) validate_artifacts( config_json_path, artifacts, diff --git a/tests/unit/test_modal_local_area.py b/tests/unit/test_modal_local_area.py index 08ab03344..ab3bb5ffd 100644 --- a/tests/unit/test_modal_local_area.py +++ b/tests/unit/test_modal_local_area.py @@ -328,3 +328,47 @@ def build(self, **kwargs): assert bundle.manifest_path == ( artifacts_dir / "bootstrap" / "regional" / "worker_bootstrap.json" ) + + +def test_build_worker_calibration_inputs_includes_existing_run_config_and_package( + tmp_path, +): + local_area = load_local_area_module() + run_config_path = tmp_path / "unified_run_config.json" + package_path = tmp_path / "calibration_package.pkl" + run_config_path.write_text("{}") + package_path.write_bytes(b"package") + + inputs = local_area._build_worker_calibration_inputs( + weights_path=tmp_path / "calibration_weights.npy", + geography_path=tmp_path / "geography_assignment.npz", + dataset_path=tmp_path / "source.h5", + db_path=tmp_path / "policy_data.db", + n_clones=430, + seed=42, + run_config_path=run_config_path, + calibration_package_path=package_path, + ) + + assert inputs["run_config"] == str(run_config_path) + assert inputs["calibration_package"] == str(package_path) + assert inputs["n_clones"] == 430 + assert inputs["seed"] == 42 + + +def test_build_worker_calibration_inputs_omits_missing_optional_files(tmp_path): + local_area = load_local_area_module() + + inputs = local_area._build_worker_calibration_inputs( + weights_path=tmp_path / "national_calibration_weights.npy", + geography_path=tmp_path / "national_geography_assignment.npz", + dataset_path=tmp_path / "source.h5", + db_path=tmp_path / "policy_data.db", + n_clones=430, + seed=42, + run_config_path=tmp_path / "missing_config.json", + calibration_package_path=tmp_path / "missing_package.pkl", + ) + + assert "run_config" not in inputs + assert "calibration_package" not in inputs From af243306156eb90142a83893c25178699106573a Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 16:12:24 +0200 Subject: [PATCH 03/17] Validate local H5 worker bootstrap inputs --- .../build_outputs/worker_session.py | 121 +++++++++++++++++- .../unit/build_outputs/test_worker_session.py | 95 ++++++++++++++ 2 files changed, 215 insertions(+), 1 deletion(-) diff --git a/policyengine_us_data/build_outputs/worker_session.py b/policyengine_us_data/build_outputs/worker_session.py index 694f7f200..355833293 100644 --- a/policyengine_us_data/build_outputs/worker_session.py +++ b/policyengine_us_data/build_outputs/worker_session.py @@ -16,7 +16,7 @@ WorkerBootstrapStore, load_entity_graph, ) -from .fingerprinting import PublishingInputBundle +from .fingerprinting import FingerprintingService, PublishingInputBundle from .geography_loader import CalibrationGeographyLoader from .source_dataset import PolicyEngineDatasetReader, SourceDatasetSnapshot from .validation import AreaValidationService, ValidationContext, ValidationPolicy @@ -97,6 +97,7 @@ def __init__( dataset_reader: PolicyEngineDatasetReader | None = None, geography_loader: CalibrationGeographyLoader | None = None, validation_service: AreaValidationService | None = None, + fingerprinting_service: FingerprintingService | None = None, bootstrap_store: WorkerBootstrapStore | None = None, ) -> None: """Create a session factory with injectable seams for tests.""" @@ -104,6 +105,9 @@ def __init__( self._dataset_reader = dataset_reader or PolicyEngineDatasetReader() self._geography_loader = geography_loader or CalibrationGeographyLoader() self._validation_service = validation_service or AreaValidationService() + self._fingerprinting_service = fingerprinting_service or FingerprintingService( + geography_loader=self._geography_loader + ) self._bootstrap_store = bootstrap_store def create( @@ -132,6 +136,14 @@ def create( bootstrap_store=bootstrap_store, scope=scope, ) + if bundle is not None: + bootstrap_error = self._validate_bootstrap_bundle( + bundle=bundle, + inputs=inputs, + scope=scope, + ) + if bootstrap_error is not None: + bundle = None source, bootstrap_status, source_error = self._load_source( inputs=inputs, bundle=bundle, @@ -198,6 +210,75 @@ def _load_bootstrap( except Exception as exc: return None, exc + def _validate_bootstrap_bundle( + self, + *, + bundle: WorkerBootstrapBundle, + inputs: PublishingInputBundle, + scope: BootstrapScope, + ) -> Exception | None: + try: + self._raise_for_bootstrap_mismatch( + bundle=bundle, + inputs=inputs, + scope=scope, + ) + except Exception as exc: + return exc + return None + + def _raise_for_bootstrap_mismatch( + self, + *, + bundle: WorkerBootstrapBundle, + inputs: PublishingInputBundle, + scope: BootstrapScope, + ) -> None: + if bundle.run_id != inputs.run_id: + raise ValueError( + f"Bootstrap run_id {bundle.run_id!r} does not match " + f"worker run_id {inputs.run_id!r}" + ) + if bundle.scope != scope: + raise ValueError( + f"Bootstrap scope {bundle.scope!r} does not match " + f"worker scope {scope!r}" + ) + + traceability = self._fingerprinting_service.build_traceability( + inputs=inputs, + scope=scope, + ) + current_inputs = { + "weights": _artifact_identity_manifest(traceability.weights), + "source_dataset": _artifact_identity_manifest(traceability.source_dataset), + "exact_geography": _artifact_identity_manifest( + traceability.exact_geography + ), + "target_db": _artifact_identity_manifest(traceability.target_db), + "calibration_package": _artifact_identity_manifest( + traceability.calibration_package + ), + "run_config": _artifact_identity_manifest(traceability.run_config), + } + + for logical_name, current_identity in current_inputs.items(): + _assert_manifest_identity_matches( + logical_name=logical_name, + expected=current_identity, + actual=bundle.inputs.get(logical_name), + ) + + expected_fingerprint = self._fingerprinting_service.compute_scope_fingerprint( + traceability + ) + actual_fingerprint = bundle.traceability.get("scope_fingerprint") + if actual_fingerprint != expected_fingerprint: + raise ValueError( + f"Bootstrap scope fingerprint {actual_fingerprint!r} does not " + f"match current fingerprint {expected_fingerprint!r}" + ) + def _load_source( self, *, @@ -243,3 +324,41 @@ def _load_weights( f"expected {inputs.n_clones}" ) return weights + + +def _artifact_identity_manifest(identity) -> dict[str, Any] | None: + if identity is None: + return None + return { + "logical_name": identity.logical_name, + "sha256": identity.sha256, + "size_bytes": identity.size_bytes, + "metadata": dict(identity.metadata), + } + + +def _assert_manifest_identity_matches( + *, + logical_name: str, + expected: dict[str, Any] | None, + actual, +) -> None: + if expected is None or actual is None: + if expected != actual: + raise ValueError( + f"Bootstrap {logical_name} identity presence does not match " + "current inputs" + ) + return + + actual_manifest = dict(actual) + comparable_actual = { + "logical_name": actual_manifest.get("logical_name"), + "sha256": actual_manifest.get("sha256"), + "size_bytes": actual_manifest.get("size_bytes"), + "metadata": dict(actual_manifest.get("metadata") or {}), + } + if comparable_actual != expected: + raise ValueError( + f"Bootstrap {logical_name} identity does not match current inputs" + ) diff --git a/tests/unit/build_outputs/test_worker_session.py b/tests/unit/build_outputs/test_worker_session.py index 06902060b..922134ba6 100644 --- a/tests/unit/build_outputs/test_worker_session.py +++ b/tests/unit/build_outputs/test_worker_session.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import replace from pathlib import Path from types import SimpleNamespace @@ -74,6 +75,17 @@ def prepare_context(self, **kwargs): ) +class MismatchedWeightFingerprintingService(FakeFingerprintingService): + """Fingerprinting fake that reports a changed current weights identity.""" + + def build_traceability(self, **kwargs): + traceability = super().build_traceability(**kwargs) + return replace( + traceability, + weights=replace(traceability.weights, sha256="sha256:changed-weights"), + ) + + def test_worker_session_caches_are_per_session(tmp_path): artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") @@ -142,6 +154,7 @@ def test_worker_session_factory_prefers_bootstrap_entity_graph(tmp_path): dataset_reader=dataset_reader, geography_loader=SessionGeographyLoader(artifacts), validation_service=FakeValidationService(), + fingerprinting_service=FakeFingerprintingService(), bootstrap_store=store, ).create( inputs=artifacts.inputs, @@ -179,6 +192,7 @@ def test_worker_session_factory_falls_back_when_bootstrap_source_load_fails(tmp_ dataset_reader=dataset_reader, geography_loader=SessionGeographyLoader(artifacts), validation_service=FakeValidationService(), + fingerprinting_service=FakeFingerprintingService(), bootstrap_store=store, ).create( inputs=artifacts.inputs, @@ -193,6 +207,87 @@ def test_worker_session_factory_falls_back_when_bootstrap_source_load_fails(tmp_ assert "entity graph load failed" in session.caches["bootstrap_error"] +def test_worker_session_factory_falls_back_when_bootstrap_inputs_mismatch(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + store = WorkerBootstrapStore(tmp_path / "artifacts") + WorkerBootstrapBuilder( + dataset_reader=FakeDatasetReader(artifacts.snapshot), + geography_loader=FakeGeographyLoader(artifacts), + fingerprinting_service=FakeFingerprintingService(), + ).build( + inputs=artifacts.inputs, + scope="regional", + artifacts_dir=store.artifacts_dir, + ) + changed_inputs = type(artifacts.inputs)( + weights_path=artifacts.inputs.weights_path, + source_dataset_path=artifacts.inputs.source_dataset_path, + target_db_path=artifacts.inputs.target_db_path, + exact_geography_path=artifacts.inputs.exact_geography_path, + calibration_package_path=artifacts.inputs.calibration_package_path, + run_config_path=artifacts.inputs.run_config_path, + run_id="different-run", + version=artifacts.inputs.version, + n_clones=artifacts.inputs.n_clones, + seed=artifacts.inputs.seed, + legacy_blocks_path=artifacts.inputs.legacy_blocks_path, + ) + dataset_reader = SessionDatasetReader(artifacts.snapshot) + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=SessionGeographyLoader(artifacts), + validation_service=FakeValidationService(), + fingerprinting_service=FakeFingerprintingService(), + bootstrap_store=store, + ).create( + inputs=changed_inputs, + scope="regional", + validation_policy=ValidationPolicy(), + period=2024, + ) + + assert session.bootstrap_status == "fallback" + assert session.bootstrap_bundle is None + assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] + assert "does not match worker run_id" in session.caches["bootstrap_error"] + + +def test_worker_session_factory_falls_back_when_bootstrap_artifact_identity_mismatch( + tmp_path, +): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + store = WorkerBootstrapStore(tmp_path / "artifacts") + WorkerBootstrapBuilder( + dataset_reader=FakeDatasetReader(artifacts.snapshot), + geography_loader=FakeGeographyLoader(artifacts), + fingerprinting_service=FakeFingerprintingService(), + ).build( + inputs=artifacts.inputs, + scope="regional", + artifacts_dir=store.artifacts_dir, + ) + dataset_reader = SessionDatasetReader(artifacts.snapshot) + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=SessionGeographyLoader(artifacts), + validation_service=FakeValidationService(), + fingerprinting_service=MismatchedWeightFingerprintingService(), + bootstrap_store=store, + ).create( + inputs=artifacts.inputs, + scope="regional", + validation_policy=ValidationPolicy(), + period=2024, + ) + + assert session.bootstrap_status == "fallback" + assert session.bootstrap_bundle is None + assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] + assert "weights identity does not match" in session.caches["bootstrap_error"] + + def test_worker_session_factory_marks_corrupt_bootstrap_as_fallback(tmp_path): artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") store = WorkerBootstrapStore(tmp_path / "artifacts") From 6bec2b61bcd820f4eadf4835bf176554824e79f8 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 16:31:43 +0200 Subject: [PATCH 04/17] Avoid worker-side bootstrap rehashing --- .../build_outputs/worker_session.py | 102 ++++++++---------- .../unit/build_outputs/test_worker_session.py | 34 +++--- 2 files changed, 60 insertions(+), 76 deletions(-) diff --git a/policyengine_us_data/build_outputs/worker_session.py b/policyengine_us_data/build_outputs/worker_session.py index 355833293..06945f465 100644 --- a/policyengine_us_data/build_outputs/worker_session.py +++ b/policyengine_us_data/build_outputs/worker_session.py @@ -16,7 +16,7 @@ WorkerBootstrapStore, load_entity_graph, ) -from .fingerprinting import FingerprintingService, PublishingInputBundle +from .fingerprinting import PublishingInputBundle from .geography_loader import CalibrationGeographyLoader from .source_dataset import PolicyEngineDatasetReader, SourceDatasetSnapshot from .validation import AreaValidationService, ValidationContext, ValidationPolicy @@ -97,7 +97,6 @@ def __init__( dataset_reader: PolicyEngineDatasetReader | None = None, geography_loader: CalibrationGeographyLoader | None = None, validation_service: AreaValidationService | None = None, - fingerprinting_service: FingerprintingService | None = None, bootstrap_store: WorkerBootstrapStore | None = None, ) -> None: """Create a session factory with injectable seams for tests.""" @@ -105,9 +104,6 @@ def __init__( self._dataset_reader = dataset_reader or PolicyEngineDatasetReader() self._geography_loader = geography_loader or CalibrationGeographyLoader() self._validation_service = validation_service or AreaValidationService() - self._fingerprinting_service = fingerprinting_service or FingerprintingService( - geography_loader=self._geography_loader - ) self._bootstrap_store = bootstrap_store def create( @@ -245,39 +241,27 @@ def _raise_for_bootstrap_mismatch( f"worker scope {scope!r}" ) - traceability = self._fingerprinting_service.build_traceability( - inputs=inputs, - scope=scope, - ) - current_inputs = { - "weights": _artifact_identity_manifest(traceability.weights), - "source_dataset": _artifact_identity_manifest(traceability.source_dataset), - "exact_geography": _artifact_identity_manifest( - traceability.exact_geography - ), - "target_db": _artifact_identity_manifest(traceability.target_db), - "calibration_package": _artifact_identity_manifest( - traceability.calibration_package - ), - "run_config": _artifact_identity_manifest(traceability.run_config), + expected_paths = { + "weights": inputs.weights_path, + "source_dataset": inputs.source_dataset_path, + "exact_geography": inputs.exact_geography_path, + "target_db": inputs.target_db_path, + "calibration_package": inputs.calibration_package_path, + "run_config": inputs.run_config_path, } - - for logical_name, current_identity in current_inputs.items(): - _assert_manifest_identity_matches( + for logical_name, expected_path in expected_paths.items(): + _assert_manifest_path_matches( logical_name=logical_name, - expected=current_identity, - actual=bundle.inputs.get(logical_name), + expected_path=expected_path, + manifest_identity=bundle.inputs.get(logical_name), ) - expected_fingerprint = self._fingerprinting_service.compute_scope_fingerprint( - traceability + _assert_summary_field_matches( + section="weights", + field="n_clones", + expected=inputs.n_clones, + actual=bundle.weights.get("n_clones"), ) - actual_fingerprint = bundle.traceability.get("scope_fingerprint") - if actual_fingerprint != expected_fingerprint: - raise ValueError( - f"Bootstrap scope fingerprint {actual_fingerprint!r} does not " - f"match current fingerprint {expected_fingerprint!r}" - ) def _load_source( self, @@ -326,39 +310,41 @@ def _load_weights( return weights -def _artifact_identity_manifest(identity) -> dict[str, Any] | None: - if identity is None: - return None - return { - "logical_name": identity.logical_name, - "sha256": identity.sha256, - "size_bytes": identity.size_bytes, - "metadata": dict(identity.metadata), - } - - -def _assert_manifest_identity_matches( +def _assert_manifest_path_matches( *, logical_name: str, - expected: dict[str, Any] | None, - actual, + expected_path: Path | None, + manifest_identity, ) -> None: - if expected is None or actual is None: - if expected != actual: + if expected_path is None: + if manifest_identity is not None: raise ValueError( f"Bootstrap {logical_name} identity presence does not match " "current inputs" ) return - actual_manifest = dict(actual) - comparable_actual = { - "logical_name": actual_manifest.get("logical_name"), - "sha256": actual_manifest.get("sha256"), - "size_bytes": actual_manifest.get("size_bytes"), - "metadata": dict(actual_manifest.get("metadata") or {}), - } - if comparable_actual != expected: + if manifest_identity is None: + raise ValueError( + f"Bootstrap {logical_name} identity presence does not match current inputs" + ) + + actual_path = manifest_identity.get("path") + if actual_path is None or Path(actual_path) != Path(expected_path): + raise ValueError(f"Bootstrap {logical_name} path does not match current inputs") + + +def _assert_summary_field_matches( + *, + section: str, + field: str, + expected, + actual, +) -> None: + if expected is None: + return + if actual != expected: raise ValueError( - f"Bootstrap {logical_name} identity does not match current inputs" + f"Bootstrap {section}.{field} {actual!r} does not match " + f"current value {expected!r}" ) diff --git a/tests/unit/build_outputs/test_worker_session.py b/tests/unit/build_outputs/test_worker_session.py index 922134ba6..a2d44494c 100644 --- a/tests/unit/build_outputs/test_worker_session.py +++ b/tests/unit/build_outputs/test_worker_session.py @@ -1,6 +1,5 @@ from __future__ import annotations -from dataclasses import replace from pathlib import Path from types import SimpleNamespace @@ -75,17 +74,6 @@ def prepare_context(self, **kwargs): ) -class MismatchedWeightFingerprintingService(FakeFingerprintingService): - """Fingerprinting fake that reports a changed current weights identity.""" - - def build_traceability(self, **kwargs): - traceability = super().build_traceability(**kwargs) - return replace( - traceability, - weights=replace(traceability.weights, sha256="sha256:changed-weights"), - ) - - def test_worker_session_caches_are_per_session(tmp_path): artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") @@ -154,7 +142,6 @@ def test_worker_session_factory_prefers_bootstrap_entity_graph(tmp_path): dataset_reader=dataset_reader, geography_loader=SessionGeographyLoader(artifacts), validation_service=FakeValidationService(), - fingerprinting_service=FakeFingerprintingService(), bootstrap_store=store, ).create( inputs=artifacts.inputs, @@ -192,7 +179,6 @@ def test_worker_session_factory_falls_back_when_bootstrap_source_load_fails(tmp_ dataset_reader=dataset_reader, geography_loader=SessionGeographyLoader(artifacts), validation_service=FakeValidationService(), - fingerprinting_service=FakeFingerprintingService(), bootstrap_store=store, ).create( inputs=artifacts.inputs, @@ -238,7 +224,6 @@ def test_worker_session_factory_falls_back_when_bootstrap_inputs_mismatch(tmp_pa dataset_reader=dataset_reader, geography_loader=SessionGeographyLoader(artifacts), validation_service=FakeValidationService(), - fingerprinting_service=FakeFingerprintingService(), bootstrap_store=store, ).create( inputs=changed_inputs, @@ -267,16 +252,29 @@ def test_worker_session_factory_falls_back_when_bootstrap_artifact_identity_mism scope="regional", artifacts_dir=store.artifacts_dir, ) + changed_inputs = type(artifacts.inputs)( + weights_path=tmp_path / "other_weights.npy", + source_dataset_path=artifacts.inputs.source_dataset_path, + target_db_path=artifacts.inputs.target_db_path, + exact_geography_path=artifacts.inputs.exact_geography_path, + calibration_package_path=artifacts.inputs.calibration_package_path, + run_config_path=artifacts.inputs.run_config_path, + run_id=artifacts.inputs.run_id, + version=artifacts.inputs.version, + n_clones=artifacts.inputs.n_clones, + seed=artifacts.inputs.seed, + legacy_blocks_path=artifacts.inputs.legacy_blocks_path, + ) + changed_inputs.weights_path.write_bytes(artifacts.inputs.weights_path.read_bytes()) dataset_reader = SessionDatasetReader(artifacts.snapshot) session = WorkerSessionFactory( dataset_reader=dataset_reader, geography_loader=SessionGeographyLoader(artifacts), validation_service=FakeValidationService(), - fingerprinting_service=MismatchedWeightFingerprintingService(), bootstrap_store=store, ).create( - inputs=artifacts.inputs, + inputs=changed_inputs, scope="regional", validation_policy=ValidationPolicy(), period=2024, @@ -285,7 +283,7 @@ def test_worker_session_factory_falls_back_when_bootstrap_artifact_identity_mism assert session.bootstrap_status == "fallback" assert session.bootstrap_bundle is None assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] - assert "weights identity does not match" in session.caches["bootstrap_error"] + assert "weights path does not match" in session.caches["bootstrap_error"] def test_worker_session_factory_marks_corrupt_bootstrap_as_fallback(tmp_path): From a27057527727bb7bee69c22012075f73d6c77d05 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 16:33:49 +0200 Subject: [PATCH 05/17] Revert "Avoid worker-side bootstrap rehashing" This reverts commit 7093a40bde6e73acd08523b49826863cc1511fec. --- .../build_outputs/worker_session.py | 102 ++++++++++-------- .../unit/build_outputs/test_worker_session.py | 34 +++--- 2 files changed, 76 insertions(+), 60 deletions(-) diff --git a/policyengine_us_data/build_outputs/worker_session.py b/policyengine_us_data/build_outputs/worker_session.py index 06945f465..355833293 100644 --- a/policyengine_us_data/build_outputs/worker_session.py +++ b/policyengine_us_data/build_outputs/worker_session.py @@ -16,7 +16,7 @@ WorkerBootstrapStore, load_entity_graph, ) -from .fingerprinting import PublishingInputBundle +from .fingerprinting import FingerprintingService, PublishingInputBundle from .geography_loader import CalibrationGeographyLoader from .source_dataset import PolicyEngineDatasetReader, SourceDatasetSnapshot from .validation import AreaValidationService, ValidationContext, ValidationPolicy @@ -97,6 +97,7 @@ def __init__( dataset_reader: PolicyEngineDatasetReader | None = None, geography_loader: CalibrationGeographyLoader | None = None, validation_service: AreaValidationService | None = None, + fingerprinting_service: FingerprintingService | None = None, bootstrap_store: WorkerBootstrapStore | None = None, ) -> None: """Create a session factory with injectable seams for tests.""" @@ -104,6 +105,9 @@ def __init__( self._dataset_reader = dataset_reader or PolicyEngineDatasetReader() self._geography_loader = geography_loader or CalibrationGeographyLoader() self._validation_service = validation_service or AreaValidationService() + self._fingerprinting_service = fingerprinting_service or FingerprintingService( + geography_loader=self._geography_loader + ) self._bootstrap_store = bootstrap_store def create( @@ -241,27 +245,39 @@ def _raise_for_bootstrap_mismatch( f"worker scope {scope!r}" ) - expected_paths = { - "weights": inputs.weights_path, - "source_dataset": inputs.source_dataset_path, - "exact_geography": inputs.exact_geography_path, - "target_db": inputs.target_db_path, - "calibration_package": inputs.calibration_package_path, - "run_config": inputs.run_config_path, + traceability = self._fingerprinting_service.build_traceability( + inputs=inputs, + scope=scope, + ) + current_inputs = { + "weights": _artifact_identity_manifest(traceability.weights), + "source_dataset": _artifact_identity_manifest(traceability.source_dataset), + "exact_geography": _artifact_identity_manifest( + traceability.exact_geography + ), + "target_db": _artifact_identity_manifest(traceability.target_db), + "calibration_package": _artifact_identity_manifest( + traceability.calibration_package + ), + "run_config": _artifact_identity_manifest(traceability.run_config), } - for logical_name, expected_path in expected_paths.items(): - _assert_manifest_path_matches( + + for logical_name, current_identity in current_inputs.items(): + _assert_manifest_identity_matches( logical_name=logical_name, - expected_path=expected_path, - manifest_identity=bundle.inputs.get(logical_name), + expected=current_identity, + actual=bundle.inputs.get(logical_name), ) - _assert_summary_field_matches( - section="weights", - field="n_clones", - expected=inputs.n_clones, - actual=bundle.weights.get("n_clones"), + expected_fingerprint = self._fingerprinting_service.compute_scope_fingerprint( + traceability ) + actual_fingerprint = bundle.traceability.get("scope_fingerprint") + if actual_fingerprint != expected_fingerprint: + raise ValueError( + f"Bootstrap scope fingerprint {actual_fingerprint!r} does not " + f"match current fingerprint {expected_fingerprint!r}" + ) def _load_source( self, @@ -310,41 +326,39 @@ def _load_weights( return weights -def _assert_manifest_path_matches( +def _artifact_identity_manifest(identity) -> dict[str, Any] | None: + if identity is None: + return None + return { + "logical_name": identity.logical_name, + "sha256": identity.sha256, + "size_bytes": identity.size_bytes, + "metadata": dict(identity.metadata), + } + + +def _assert_manifest_identity_matches( *, logical_name: str, - expected_path: Path | None, - manifest_identity, + expected: dict[str, Any] | None, + actual, ) -> None: - if expected_path is None: - if manifest_identity is not None: + if expected is None or actual is None: + if expected != actual: raise ValueError( f"Bootstrap {logical_name} identity presence does not match " "current inputs" ) return - if manifest_identity is None: - raise ValueError( - f"Bootstrap {logical_name} identity presence does not match current inputs" - ) - - actual_path = manifest_identity.get("path") - if actual_path is None or Path(actual_path) != Path(expected_path): - raise ValueError(f"Bootstrap {logical_name} path does not match current inputs") - - -def _assert_summary_field_matches( - *, - section: str, - field: str, - expected, - actual, -) -> None: - if expected is None: - return - if actual != expected: + actual_manifest = dict(actual) + comparable_actual = { + "logical_name": actual_manifest.get("logical_name"), + "sha256": actual_manifest.get("sha256"), + "size_bytes": actual_manifest.get("size_bytes"), + "metadata": dict(actual_manifest.get("metadata") or {}), + } + if comparable_actual != expected: raise ValueError( - f"Bootstrap {section}.{field} {actual!r} does not match " - f"current value {expected!r}" + f"Bootstrap {logical_name} identity does not match current inputs" ) diff --git a/tests/unit/build_outputs/test_worker_session.py b/tests/unit/build_outputs/test_worker_session.py index a2d44494c..922134ba6 100644 --- a/tests/unit/build_outputs/test_worker_session.py +++ b/tests/unit/build_outputs/test_worker_session.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import replace from pathlib import Path from types import SimpleNamespace @@ -74,6 +75,17 @@ def prepare_context(self, **kwargs): ) +class MismatchedWeightFingerprintingService(FakeFingerprintingService): + """Fingerprinting fake that reports a changed current weights identity.""" + + def build_traceability(self, **kwargs): + traceability = super().build_traceability(**kwargs) + return replace( + traceability, + weights=replace(traceability.weights, sha256="sha256:changed-weights"), + ) + + def test_worker_session_caches_are_per_session(tmp_path): artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") @@ -142,6 +154,7 @@ def test_worker_session_factory_prefers_bootstrap_entity_graph(tmp_path): dataset_reader=dataset_reader, geography_loader=SessionGeographyLoader(artifacts), validation_service=FakeValidationService(), + fingerprinting_service=FakeFingerprintingService(), bootstrap_store=store, ).create( inputs=artifacts.inputs, @@ -179,6 +192,7 @@ def test_worker_session_factory_falls_back_when_bootstrap_source_load_fails(tmp_ dataset_reader=dataset_reader, geography_loader=SessionGeographyLoader(artifacts), validation_service=FakeValidationService(), + fingerprinting_service=FakeFingerprintingService(), bootstrap_store=store, ).create( inputs=artifacts.inputs, @@ -224,6 +238,7 @@ def test_worker_session_factory_falls_back_when_bootstrap_inputs_mismatch(tmp_pa dataset_reader=dataset_reader, geography_loader=SessionGeographyLoader(artifacts), validation_service=FakeValidationService(), + fingerprinting_service=FakeFingerprintingService(), bootstrap_store=store, ).create( inputs=changed_inputs, @@ -252,29 +267,16 @@ def test_worker_session_factory_falls_back_when_bootstrap_artifact_identity_mism scope="regional", artifacts_dir=store.artifacts_dir, ) - changed_inputs = type(artifacts.inputs)( - weights_path=tmp_path / "other_weights.npy", - source_dataset_path=artifacts.inputs.source_dataset_path, - target_db_path=artifacts.inputs.target_db_path, - exact_geography_path=artifacts.inputs.exact_geography_path, - calibration_package_path=artifacts.inputs.calibration_package_path, - run_config_path=artifacts.inputs.run_config_path, - run_id=artifacts.inputs.run_id, - version=artifacts.inputs.version, - n_clones=artifacts.inputs.n_clones, - seed=artifacts.inputs.seed, - legacy_blocks_path=artifacts.inputs.legacy_blocks_path, - ) - changed_inputs.weights_path.write_bytes(artifacts.inputs.weights_path.read_bytes()) dataset_reader = SessionDatasetReader(artifacts.snapshot) session = WorkerSessionFactory( dataset_reader=dataset_reader, geography_loader=SessionGeographyLoader(artifacts), validation_service=FakeValidationService(), + fingerprinting_service=MismatchedWeightFingerprintingService(), bootstrap_store=store, ).create( - inputs=changed_inputs, + inputs=artifacts.inputs, scope="regional", validation_policy=ValidationPolicy(), period=2024, @@ -283,7 +285,7 @@ def test_worker_session_factory_falls_back_when_bootstrap_artifact_identity_mism assert session.bootstrap_status == "fallback" assert session.bootstrap_bundle is None assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] - assert "weights path does not match" in session.caches["bootstrap_error"] + assert "weights identity does not match" in session.caches["bootstrap_error"] def test_worker_session_factory_marks_corrupt_bootstrap_as_fallback(tmp_path): From 70e2a1b41b80c278b12b3cff6dc3435857dfb227 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 19:53:09 +0200 Subject: [PATCH 06/17] Reject stale H5 scope fingerprints --- modal_app/local_area.py | 17 ++--- modal_app/pipeline.py | 66 +++++++++++++++++-- .../build_outputs/bootstrap.py | 14 +++- tests/unit/build_outputs/test_bootstrap.py | 22 ++++++- tests/unit/test_modal_local_area.py | 40 ++++++----- 5 files changed, 116 insertions(+), 43 deletions(-) diff --git a/modal_app/local_area.py b/modal_app/local_area.py index c4cc88047..3b2237b81 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -427,23 +427,20 @@ def _resolve_scope_fingerprint( scope: str, expected_fingerprint: str = "", ) -> str: - """Compute the scope fingerprint while preserving pinned resume values.""" + """Compute the scope fingerprint and reject stale resume fingerprints.""" service = FingerprintingService() traceability = service.build_traceability(inputs=inputs, scope=scope) computed_fingerprint = service.compute_scope_fingerprint(traceability) if expected_fingerprint: if expected_fingerprint != computed_fingerprint: - print( - "WARNING: Pinned fingerprint differs from current " - f"{scope} scope fingerprint. " - "Preserving pinned value for backward-compatible resume.\n" - f" Pinned: {expected_fingerprint}\n" - f" Current: {computed_fingerprint}" + raise RuntimeError( + f"Cannot resume {scope} H5 build with changed inputs.\n" + f" Expected: {expected_fingerprint}\n" + f" Current: {computed_fingerprint}\n" + "Start a fresh run or clear stale staged outputs explicitly." ) - else: - print(f"Using pinned fingerprint from pipeline: {expected_fingerprint}") - return expected_fingerprint + print(f"Validated expected {scope} fingerprint: {expected_fingerprint}") return computed_fingerprint diff --git a/modal_app/pipeline.py b/modal_app/pipeline.py index 12b42d2c8..a36dae1f9 100644 --- a/modal_app/pipeline.py +++ b/modal_app/pipeline.py @@ -249,6 +249,8 @@ def archive_diagnostics( from modal_app.local_area import ( # noqa: E402 coordinate_publish, coordinate_national_publish, + _build_publishing_input_bundle, + _resolve_scope_fingerprint, ) app.include(_local_area_app) @@ -1419,6 +1421,50 @@ def run_pipeline( "skip_upload": False, "skip_national": skip_national, } + regional_fingerprint_inputs = _build_publishing_input_bundle( + weights_path=_artifacts_dir(run_id) / "calibration_weights.npy", + dataset_path=_artifacts_dir(run_id) + / "source_imputed_stratified_extended_cps.h5", + db_path=_artifacts_dir(run_id) / "policy_data.db", + geography_path=_artifacts_dir(run_id) / "geography_assignment.npz", + calibration_package_path=_artifacts_dir(run_id) / "calibration_package.pkl", + run_config_path=_artifacts_dir(run_id) / "unified_run_config.json", + run_id=run_id, + version=version, + n_clones=n_clones, + seed=42, + legacy_blocks_path=_artifacts_dir(run_id) / "stacked_blocks.npy", + ) + regional_scope_fingerprint = _resolve_scope_fingerprint( + inputs=regional_fingerprint_inputs, + scope="regional", + ) + regional_h5_inputs["h5_scope_fingerprint"] = regional_scope_fingerprint + + national_scope_fingerprint = None + if not skip_national: + national_fingerprint_inputs = _build_publishing_input_bundle( + weights_path=_artifacts_dir(run_id) + / "national_calibration_weights.npy", + dataset_path=_artifacts_dir(run_id) + / "source_imputed_stratified_extended_cps.h5", + db_path=_artifacts_dir(run_id) / "policy_data.db", + geography_path=_artifacts_dir(run_id) + / "national_geography_assignment.npz", + calibration_package_path=None, + run_config_path=_artifacts_dir(run_id) + / "national_unified_run_config.json", + run_id=run_id, + version=version, + n_clones=n_clones, + seed=42, + ) + national_scope_fingerprint = _resolve_scope_fingerprint( + inputs=national_fingerprint_inputs, + scope="national", + ) + national_h5_inputs["h5_scope_fingerprint"] = national_scope_fingerprint + regional_h5_reuse = _step_reusable( meta, LOCAL_AREA_H5_REGIONAL, @@ -1475,9 +1521,6 @@ def run_pipeline( n_clones=n_clones, validate=True, run_id=run_id, - expected_fingerprint=( - meta.regional_fingerprint or meta.fingerprint or "" - ), ) print(f" → coordinate_publish fc: {regional_h5_handle.object_id}") regional_h5_manifest = _start_step_manifest( @@ -1542,12 +1585,16 @@ def run_pipeline( if isinstance(regional_h5_result, dict) and regional_h5_result.get( "fingerprint" ): - meta.regional_fingerprint = regional_h5_result["fingerprint"] - meta.fingerprint = regional_h5_result["fingerprint"] + if regional_h5_result["fingerprint"] != regional_scope_fingerprint: + raise RuntimeError( + "Regional H5 fingerprint changed between pipeline " + "reuse planning and child publish completion.\n" + f" Planned: {regional_scope_fingerprint}\n" + f" Actual: {regional_h5_result['fingerprint']}" + ) regional_h5_manifest.input_identities["h5_scope_fingerprint"] = ( regional_h5_result["fingerprint"] ) - write_run_meta(meta, pipeline_volume) regional_reuse_measurement = ReuseMeasurement.from_dict( regional_h5_result.get("reuse_measurement", {}) if isinstance(regional_h5_result, dict) @@ -1571,6 +1618,13 @@ def run_pipeline( if isinstance(national_h5_result, dict) and national_h5_result.get( "fingerprint" ): + if national_h5_result["fingerprint"] != national_scope_fingerprint: + raise RuntimeError( + "National H5 fingerprint changed between pipeline " + "reuse planning and child publish completion.\n" + f" Planned: {national_scope_fingerprint}\n" + f" Actual: {national_h5_result['fingerprint']}" + ) national_h5_manifest.input_identities["h5_scope_fingerprint"] = ( national_h5_result["fingerprint"] ) diff --git a/policyengine_us_data/build_outputs/bootstrap.py b/policyengine_us_data/build_outputs/bootstrap.py index 337d07053..e4cd7f39a 100644 --- a/policyengine_us_data/build_outputs/bootstrap.py +++ b/policyengine_us_data/build_outputs/bootstrap.py @@ -368,10 +368,18 @@ def build( inputs=inputs, scope=scope, ) - if scope_fingerprint is None: - scope_fingerprint = self._fingerprinting_service.compute_scope_fingerprint( - traceability + computed_scope_fingerprint = ( + self._fingerprinting_service.compute_scope_fingerprint(traceability) + ) + if ( + scope_fingerprint is not None + and scope_fingerprint != computed_scope_fingerprint + ): + raise ValueError( + f"Bootstrap fingerprint {scope_fingerprint!r} does not match " + f"computed {scope} fingerprint {computed_scope_fingerprint!r}" ) + scope_fingerprint = computed_scope_fingerprint entity_graph_path = store.entity_graph_path(scope) save_entity_graph(snapshot.entity_graph, entity_graph_path) diff --git a/tests/unit/build_outputs/test_bootstrap.py b/tests/unit/build_outputs/test_bootstrap.py index 3a60b4726..bd4c55700 100644 --- a/tests/unit/build_outputs/test_bootstrap.py +++ b/tests/unit/build_outputs/test_bootstrap.py @@ -95,7 +95,7 @@ def test_worker_bootstrap_builder_persists_manifest_and_entity_graph(tmp_path): assert manifest["inputs"]["weights"]["sha256"] == "sha256:weights" -def test_worker_bootstrap_builder_preserves_resolved_fingerprint_override(tmp_path): +def test_worker_bootstrap_builder_accepts_matching_resolved_fingerprint(tmp_path): artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") bundle = WorkerBootstrapBuilder( @@ -106,11 +106,27 @@ def test_worker_bootstrap_builder_preserves_resolved_fingerprint_override(tmp_pa inputs=artifacts.inputs, scope="regional", artifacts_dir=tmp_path / "artifacts", - scope_fingerprint="pinned-fingerprint", + scope_fingerprint="regional-fingerprint", ) manifest = json.loads(bundle.manifest_path.read_text()) - assert manifest["traceability"]["scope_fingerprint"] == "pinned-fingerprint" + assert manifest["traceability"]["scope_fingerprint"] == "regional-fingerprint" + + +def test_worker_bootstrap_builder_rejects_mismatched_resolved_fingerprint(tmp_path): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + + with pytest.raises(ValueError, match="Bootstrap fingerprint"): + WorkerBootstrapBuilder( + dataset_reader=FakeDatasetReader(artifacts.snapshot), + geography_loader=FakeGeographyLoader(artifacts), + fingerprinting_service=FakeFingerprintingService(), + ).build( + inputs=artifacts.inputs, + scope="regional", + artifacts_dir=tmp_path / "artifacts", + scope_fingerprint="stale-fingerprint", + ) def test_worker_bootstrap_store_loads_persisted_bundle(tmp_path): diff --git a/tests/unit/test_modal_local_area.py b/tests/unit/test_modal_local_area.py index ab3bb5ffd..3a1cb4f5c 100644 --- a/tests/unit/test_modal_local_area.py +++ b/tests/unit/test_modal_local_area.py @@ -1,6 +1,8 @@ from pathlib import Path from types import SimpleNamespace +import pytest + from tests.support.modal_local_area import load_local_area_module @@ -202,7 +204,9 @@ def compute_scope_fingerprint(self, traceability): assert seen["traceability"] == {"scope": "regional", "run_id": "run-123"} -def test_resolve_scope_fingerprint_preserves_matching_pin(monkeypatch, capsys): +def test_resolve_scope_fingerprint_validates_matching_expected_value( + monkeypatch, capsys +): local_area = load_local_area_module(stub_policyengine=False) class FakeFingerprintingService: @@ -210,7 +214,7 @@ def build_traceability(self, *, inputs, scope): return scope def compute_scope_fingerprint(self, traceability): - return "pinned-fingerprint" + return "expected-fingerprint" monkeypatch.setattr( local_area, @@ -234,17 +238,15 @@ def compute_scope_fingerprint(self, traceability): fingerprint = local_area._resolve_scope_fingerprint( inputs=bundle, scope="regional", - expected_fingerprint="pinned-fingerprint", + expected_fingerprint="expected-fingerprint", ) captured = capsys.readouterr() - assert fingerprint == "pinned-fingerprint" - assert "Using pinned fingerprint from pipeline" in captured.out + assert fingerprint == "expected-fingerprint" + assert "Validated expected regional fingerprint" in captured.out -def test_resolve_scope_fingerprint_warns_and_preserves_mismatched_pin( - monkeypatch, capsys -): +def test_resolve_scope_fingerprint_rejects_mismatched_expected_value(monkeypatch): local_area = load_local_area_module(stub_policyengine=False) class FakeFingerprintingService: @@ -273,19 +275,15 @@ def compute_scope_fingerprint(self, traceability): seed=42, ) - fingerprint = local_area._resolve_scope_fingerprint( - inputs=bundle, - scope="national", - expected_fingerprint="legacy-fingerprint", - ) - - captured = capsys.readouterr() - assert fingerprint == "legacy-fingerprint" - assert "Pinned fingerprint differs from current national scope fingerprint" in ( - captured.out - ) - assert "legacy-fingerprint" in captured.out - assert "computed-fingerprint" in captured.out + with pytest.raises( + RuntimeError, + match="Cannot resume national H5 build with changed inputs", + ): + local_area._resolve_scope_fingerprint( + inputs=bundle, + scope="national", + expected_fingerprint="legacy-fingerprint", + ) def test_build_worker_bootstrap_invokes_builder_without_changing_inputs(monkeypatch): From 15f3f10b88e736d4e2c723198a64db4561bf683a Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 20:37:42 +0200 Subject: [PATCH 07/17] Fail closed on missing H5 fingerprints --- .../{h5-worker-session.added => 951.added} | 0 modal_app/pipeline.py | 65 ++++++++++++------- tests/unit/test_pipeline_source_contracts.py | 12 ++++ 3 files changed, 53 insertions(+), 24 deletions(-) rename changelog.d/{h5-worker-session.added => 951.added} (100%) diff --git a/changelog.d/h5-worker-session.added b/changelog.d/951.added similarity index 100% rename from changelog.d/h5-worker-session.added rename to changelog.d/951.added diff --git a/modal_app/pipeline.py b/modal_app/pipeline.py index a36dae1f9..0fafc2f2b 100644 --- a/modal_app/pipeline.py +++ b/modal_app/pipeline.py @@ -158,6 +158,35 @@ def _calibration_package_parameters( return {key: value for key, value in params.items() if value is not None} +def _require_h5_scope_fingerprint( + result, + *, + scope: str, + planned_fingerprint: str, +) -> str: + """Return a child H5 publish fingerprint or fail closed.""" + + if not isinstance(result, dict): + raise RuntimeError( + f"{scope} H5 publish returned a non-dict result; " + "cannot verify scope fingerprint." + ) + actual_fingerprint = result.get("fingerprint") + if not actual_fingerprint: + raise RuntimeError( + f"{scope} H5 publish result did not include a fingerprint; " + "refusing to complete the H5 step manifest." + ) + if actual_fingerprint != planned_fingerprint: + raise RuntimeError( + f"{scope} H5 fingerprint changed between pipeline " + "reuse planning and child publish completion.\n" + f" Planned: {planned_fingerprint}\n" + f" Actual: {actual_fingerprint}" + ) + return actual_fingerprint + + def get_pinned_sha(branch: str) -> str: """Get the current tip SHA for a branch from GitHub.""" result = subprocess.run( @@ -1582,19 +1611,13 @@ def run_pipeline( pipeline_volume.reload() staging_volume.reload() - if isinstance(regional_h5_result, dict) and regional_h5_result.get( - "fingerprint" - ): - if regional_h5_result["fingerprint"] != regional_scope_fingerprint: - raise RuntimeError( - "Regional H5 fingerprint changed between pipeline " - "reuse planning and child publish completion.\n" - f" Planned: {regional_scope_fingerprint}\n" - f" Actual: {regional_h5_result['fingerprint']}" - ) - regional_h5_manifest.input_identities["h5_scope_fingerprint"] = ( - regional_h5_result["fingerprint"] + regional_h5_manifest.input_identities["h5_scope_fingerprint"] = ( + _require_h5_scope_fingerprint( + regional_h5_result, + scope="Regional", + planned_fingerprint=regional_scope_fingerprint, ) + ) regional_reuse_measurement = ReuseMeasurement.from_dict( regional_h5_result.get("reuse_measurement", {}) if isinstance(regional_h5_result, dict) @@ -1615,19 +1638,13 @@ def run_pipeline( active_step_manifest = national_h5_manifest if national_h5_handle is not None: - if isinstance(national_h5_result, dict) and national_h5_result.get( - "fingerprint" - ): - if national_h5_result["fingerprint"] != national_scope_fingerprint: - raise RuntimeError( - "National H5 fingerprint changed between pipeline " - "reuse planning and child publish completion.\n" - f" Planned: {national_scope_fingerprint}\n" - f" Actual: {national_h5_result['fingerprint']}" - ) - national_h5_manifest.input_identities["h5_scope_fingerprint"] = ( - national_h5_result["fingerprint"] + national_h5_manifest.input_identities["h5_scope_fingerprint"] = ( + _require_h5_scope_fingerprint( + national_h5_result, + scope="National", + planned_fingerprint=national_scope_fingerprint, ) + ) national_reuse_measurement = ReuseMeasurement.from_dict( national_h5_result.get("reuse_measurement", {}) if isinstance(national_h5_result, dict) diff --git a/tests/unit/test_pipeline_source_contracts.py b/tests/unit/test_pipeline_source_contracts.py index 89bd368e4..ee6b45de2 100644 --- a/tests/unit/test_pipeline_source_contracts.py +++ b/tests/unit/test_pipeline_source_contracts.py @@ -87,6 +87,18 @@ def test_run_pipeline_refreshes_diagnostics_even_when_h5_outputs_reused() -> Non assert "Upload validation diagnostics even when H5 outputs are reused." in source +def test_run_pipeline_fails_closed_when_h5_child_fingerprint_missing() -> None: + tree = ast.parse(PIPELINE_SOURCE.read_text()) + helper = _function_def(tree, "_require_h5_scope_fingerprint") + helper_source = ast.get_source_segment(PIPELINE_SOURCE.read_text(), helper) + run_pipeline = _function_def(tree, "run_pipeline") + pipeline_source = ast.get_source_segment(PIPELINE_SOURCE.read_text(), run_pipeline) + + assert "did not include a fingerprint" in helper_source + assert "refusing to complete the H5 step manifest" in helper_source + assert pipeline_source.count("_require_h5_scope_fingerprint(") == 2 + + def test_full_release_path_combines_base_regional_and_national_outputs(): tree = ast.parse(PIPELINE_SOURCE.read_text()) helper = _function_def(tree, "_full_release_staging_rel_paths") From cc42c1c6c78857f9a6bc8b74c0c8add06978e281 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 21:13:49 +0200 Subject: [PATCH 08/17] Defer strict H5 fingerprint guards --- modal_app/local_area.py | 17 +-- modal_app/pipeline.py | 103 +++--------------- .../build_outputs/bootstrap.py | 14 +-- tests/unit/build_outputs/test_bootstrap.py | 22 +--- tests/unit/test_modal_local_area.py | 40 +++---- tests/unit/test_pipeline_source_contracts.py | 12 -- 6 files changed, 53 insertions(+), 155 deletions(-) diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 3b2237b81..c4cc88047 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -427,20 +427,23 @@ def _resolve_scope_fingerprint( scope: str, expected_fingerprint: str = "", ) -> str: - """Compute the scope fingerprint and reject stale resume fingerprints.""" + """Compute the scope fingerprint while preserving pinned resume values.""" service = FingerprintingService() traceability = service.build_traceability(inputs=inputs, scope=scope) computed_fingerprint = service.compute_scope_fingerprint(traceability) if expected_fingerprint: if expected_fingerprint != computed_fingerprint: - raise RuntimeError( - f"Cannot resume {scope} H5 build with changed inputs.\n" - f" Expected: {expected_fingerprint}\n" - f" Current: {computed_fingerprint}\n" - "Start a fresh run or clear stale staged outputs explicitly." + print( + "WARNING: Pinned fingerprint differs from current " + f"{scope} scope fingerprint. " + "Preserving pinned value for backward-compatible resume.\n" + f" Pinned: {expected_fingerprint}\n" + f" Current: {computed_fingerprint}" ) - print(f"Validated expected {scope} fingerprint: {expected_fingerprint}") + else: + print(f"Using pinned fingerprint from pipeline: {expected_fingerprint}") + return expected_fingerprint return computed_fingerprint diff --git a/modal_app/pipeline.py b/modal_app/pipeline.py index 0fafc2f2b..12b42d2c8 100644 --- a/modal_app/pipeline.py +++ b/modal_app/pipeline.py @@ -158,35 +158,6 @@ def _calibration_package_parameters( return {key: value for key, value in params.items() if value is not None} -def _require_h5_scope_fingerprint( - result, - *, - scope: str, - planned_fingerprint: str, -) -> str: - """Return a child H5 publish fingerprint or fail closed.""" - - if not isinstance(result, dict): - raise RuntimeError( - f"{scope} H5 publish returned a non-dict result; " - "cannot verify scope fingerprint." - ) - actual_fingerprint = result.get("fingerprint") - if not actual_fingerprint: - raise RuntimeError( - f"{scope} H5 publish result did not include a fingerprint; " - "refusing to complete the H5 step manifest." - ) - if actual_fingerprint != planned_fingerprint: - raise RuntimeError( - f"{scope} H5 fingerprint changed between pipeline " - "reuse planning and child publish completion.\n" - f" Planned: {planned_fingerprint}\n" - f" Actual: {actual_fingerprint}" - ) - return actual_fingerprint - - def get_pinned_sha(branch: str) -> str: """Get the current tip SHA for a branch from GitHub.""" result = subprocess.run( @@ -278,8 +249,6 @@ def archive_diagnostics( from modal_app.local_area import ( # noqa: E402 coordinate_publish, coordinate_national_publish, - _build_publishing_input_bundle, - _resolve_scope_fingerprint, ) app.include(_local_area_app) @@ -1450,50 +1419,6 @@ def run_pipeline( "skip_upload": False, "skip_national": skip_national, } - regional_fingerprint_inputs = _build_publishing_input_bundle( - weights_path=_artifacts_dir(run_id) / "calibration_weights.npy", - dataset_path=_artifacts_dir(run_id) - / "source_imputed_stratified_extended_cps.h5", - db_path=_artifacts_dir(run_id) / "policy_data.db", - geography_path=_artifacts_dir(run_id) / "geography_assignment.npz", - calibration_package_path=_artifacts_dir(run_id) / "calibration_package.pkl", - run_config_path=_artifacts_dir(run_id) / "unified_run_config.json", - run_id=run_id, - version=version, - n_clones=n_clones, - seed=42, - legacy_blocks_path=_artifacts_dir(run_id) / "stacked_blocks.npy", - ) - regional_scope_fingerprint = _resolve_scope_fingerprint( - inputs=regional_fingerprint_inputs, - scope="regional", - ) - regional_h5_inputs["h5_scope_fingerprint"] = regional_scope_fingerprint - - national_scope_fingerprint = None - if not skip_national: - national_fingerprint_inputs = _build_publishing_input_bundle( - weights_path=_artifacts_dir(run_id) - / "national_calibration_weights.npy", - dataset_path=_artifacts_dir(run_id) - / "source_imputed_stratified_extended_cps.h5", - db_path=_artifacts_dir(run_id) / "policy_data.db", - geography_path=_artifacts_dir(run_id) - / "national_geography_assignment.npz", - calibration_package_path=None, - run_config_path=_artifacts_dir(run_id) - / "national_unified_run_config.json", - run_id=run_id, - version=version, - n_clones=n_clones, - seed=42, - ) - national_scope_fingerprint = _resolve_scope_fingerprint( - inputs=national_fingerprint_inputs, - scope="national", - ) - national_h5_inputs["h5_scope_fingerprint"] = national_scope_fingerprint - regional_h5_reuse = _step_reusable( meta, LOCAL_AREA_H5_REGIONAL, @@ -1550,6 +1475,9 @@ def run_pipeline( n_clones=n_clones, validate=True, run_id=run_id, + expected_fingerprint=( + meta.regional_fingerprint or meta.fingerprint or "" + ), ) print(f" → coordinate_publish fc: {regional_h5_handle.object_id}") regional_h5_manifest = _start_step_manifest( @@ -1611,13 +1539,15 @@ def run_pipeline( pipeline_volume.reload() staging_volume.reload() - regional_h5_manifest.input_identities["h5_scope_fingerprint"] = ( - _require_h5_scope_fingerprint( - regional_h5_result, - scope="Regional", - planned_fingerprint=regional_scope_fingerprint, + if isinstance(regional_h5_result, dict) and regional_h5_result.get( + "fingerprint" + ): + meta.regional_fingerprint = regional_h5_result["fingerprint"] + meta.fingerprint = regional_h5_result["fingerprint"] + regional_h5_manifest.input_identities["h5_scope_fingerprint"] = ( + regional_h5_result["fingerprint"] ) - ) + write_run_meta(meta, pipeline_volume) regional_reuse_measurement = ReuseMeasurement.from_dict( regional_h5_result.get("reuse_measurement", {}) if isinstance(regional_h5_result, dict) @@ -1638,13 +1568,12 @@ def run_pipeline( active_step_manifest = national_h5_manifest if national_h5_handle is not None: - national_h5_manifest.input_identities["h5_scope_fingerprint"] = ( - _require_h5_scope_fingerprint( - national_h5_result, - scope="National", - planned_fingerprint=national_scope_fingerprint, + if isinstance(national_h5_result, dict) and national_h5_result.get( + "fingerprint" + ): + national_h5_manifest.input_identities["h5_scope_fingerprint"] = ( + national_h5_result["fingerprint"] ) - ) national_reuse_measurement = ReuseMeasurement.from_dict( national_h5_result.get("reuse_measurement", {}) if isinstance(national_h5_result, dict) diff --git a/policyengine_us_data/build_outputs/bootstrap.py b/policyengine_us_data/build_outputs/bootstrap.py index e4cd7f39a..337d07053 100644 --- a/policyengine_us_data/build_outputs/bootstrap.py +++ b/policyengine_us_data/build_outputs/bootstrap.py @@ -368,18 +368,10 @@ def build( inputs=inputs, scope=scope, ) - computed_scope_fingerprint = ( - self._fingerprinting_service.compute_scope_fingerprint(traceability) - ) - if ( - scope_fingerprint is not None - and scope_fingerprint != computed_scope_fingerprint - ): - raise ValueError( - f"Bootstrap fingerprint {scope_fingerprint!r} does not match " - f"computed {scope} fingerprint {computed_scope_fingerprint!r}" + if scope_fingerprint is None: + scope_fingerprint = self._fingerprinting_service.compute_scope_fingerprint( + traceability ) - scope_fingerprint = computed_scope_fingerprint entity_graph_path = store.entity_graph_path(scope) save_entity_graph(snapshot.entity_graph, entity_graph_path) diff --git a/tests/unit/build_outputs/test_bootstrap.py b/tests/unit/build_outputs/test_bootstrap.py index bd4c55700..3a60b4726 100644 --- a/tests/unit/build_outputs/test_bootstrap.py +++ b/tests/unit/build_outputs/test_bootstrap.py @@ -95,7 +95,7 @@ def test_worker_bootstrap_builder_persists_manifest_and_entity_graph(tmp_path): assert manifest["inputs"]["weights"]["sha256"] == "sha256:weights" -def test_worker_bootstrap_builder_accepts_matching_resolved_fingerprint(tmp_path): +def test_worker_bootstrap_builder_preserves_resolved_fingerprint_override(tmp_path): artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") bundle = WorkerBootstrapBuilder( @@ -106,27 +106,11 @@ def test_worker_bootstrap_builder_accepts_matching_resolved_fingerprint(tmp_path inputs=artifacts.inputs, scope="regional", artifacts_dir=tmp_path / "artifacts", - scope_fingerprint="regional-fingerprint", + scope_fingerprint="pinned-fingerprint", ) manifest = json.loads(bundle.manifest_path.read_text()) - assert manifest["traceability"]["scope_fingerprint"] == "regional-fingerprint" - - -def test_worker_bootstrap_builder_rejects_mismatched_resolved_fingerprint(tmp_path): - artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") - - with pytest.raises(ValueError, match="Bootstrap fingerprint"): - WorkerBootstrapBuilder( - dataset_reader=FakeDatasetReader(artifacts.snapshot), - geography_loader=FakeGeographyLoader(artifacts), - fingerprinting_service=FakeFingerprintingService(), - ).build( - inputs=artifacts.inputs, - scope="regional", - artifacts_dir=tmp_path / "artifacts", - scope_fingerprint="stale-fingerprint", - ) + assert manifest["traceability"]["scope_fingerprint"] == "pinned-fingerprint" def test_worker_bootstrap_store_loads_persisted_bundle(tmp_path): diff --git a/tests/unit/test_modal_local_area.py b/tests/unit/test_modal_local_area.py index 3a1cb4f5c..ab3bb5ffd 100644 --- a/tests/unit/test_modal_local_area.py +++ b/tests/unit/test_modal_local_area.py @@ -1,8 +1,6 @@ from pathlib import Path from types import SimpleNamespace -import pytest - from tests.support.modal_local_area import load_local_area_module @@ -204,9 +202,7 @@ def compute_scope_fingerprint(self, traceability): assert seen["traceability"] == {"scope": "regional", "run_id": "run-123"} -def test_resolve_scope_fingerprint_validates_matching_expected_value( - monkeypatch, capsys -): +def test_resolve_scope_fingerprint_preserves_matching_pin(monkeypatch, capsys): local_area = load_local_area_module(stub_policyengine=False) class FakeFingerprintingService: @@ -214,7 +210,7 @@ def build_traceability(self, *, inputs, scope): return scope def compute_scope_fingerprint(self, traceability): - return "expected-fingerprint" + return "pinned-fingerprint" monkeypatch.setattr( local_area, @@ -238,15 +234,17 @@ def compute_scope_fingerprint(self, traceability): fingerprint = local_area._resolve_scope_fingerprint( inputs=bundle, scope="regional", - expected_fingerprint="expected-fingerprint", + expected_fingerprint="pinned-fingerprint", ) captured = capsys.readouterr() - assert fingerprint == "expected-fingerprint" - assert "Validated expected regional fingerprint" in captured.out + assert fingerprint == "pinned-fingerprint" + assert "Using pinned fingerprint from pipeline" in captured.out -def test_resolve_scope_fingerprint_rejects_mismatched_expected_value(monkeypatch): +def test_resolve_scope_fingerprint_warns_and_preserves_mismatched_pin( + monkeypatch, capsys +): local_area = load_local_area_module(stub_policyengine=False) class FakeFingerprintingService: @@ -275,15 +273,19 @@ def compute_scope_fingerprint(self, traceability): seed=42, ) - with pytest.raises( - RuntimeError, - match="Cannot resume national H5 build with changed inputs", - ): - local_area._resolve_scope_fingerprint( - inputs=bundle, - scope="national", - expected_fingerprint="legacy-fingerprint", - ) + fingerprint = local_area._resolve_scope_fingerprint( + inputs=bundle, + scope="national", + expected_fingerprint="legacy-fingerprint", + ) + + captured = capsys.readouterr() + assert fingerprint == "legacy-fingerprint" + assert "Pinned fingerprint differs from current national scope fingerprint" in ( + captured.out + ) + assert "legacy-fingerprint" in captured.out + assert "computed-fingerprint" in captured.out def test_build_worker_bootstrap_invokes_builder_without_changing_inputs(monkeypatch): diff --git a/tests/unit/test_pipeline_source_contracts.py b/tests/unit/test_pipeline_source_contracts.py index ee6b45de2..89bd368e4 100644 --- a/tests/unit/test_pipeline_source_contracts.py +++ b/tests/unit/test_pipeline_source_contracts.py @@ -87,18 +87,6 @@ def test_run_pipeline_refreshes_diagnostics_even_when_h5_outputs_reused() -> Non assert "Upload validation diagnostics even when H5 outputs are reused." in source -def test_run_pipeline_fails_closed_when_h5_child_fingerprint_missing() -> None: - tree = ast.parse(PIPELINE_SOURCE.read_text()) - helper = _function_def(tree, "_require_h5_scope_fingerprint") - helper_source = ast.get_source_segment(PIPELINE_SOURCE.read_text(), helper) - run_pipeline = _function_def(tree, "run_pipeline") - pipeline_source = ast.get_source_segment(PIPELINE_SOURCE.read_text(), run_pipeline) - - assert "did not include a fingerprint" in helper_source - assert "refusing to complete the H5 step manifest" in helper_source - assert pipeline_source.count("_require_h5_scope_fingerprint(") == 2 - - def test_full_release_path_combines_base_regional_and_national_outputs(): tree = ast.parse(PIPELINE_SOURCE.read_text()) helper = _function_def(tree, "_full_release_staging_rel_paths") From 0db06695cfb4e935971f249144bf3b0cfffdc167 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 21:25:57 +0200 Subject: [PATCH 09/17] Remove worker release version argument --- modal_app/worker_script.py | 7 +------ tests/unit/test_modal_worker_script.py | 1 + 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index f76f34efc..c94d1ffe8 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -152,11 +152,6 @@ def parse_args(argv: list[str] | None = None): default=None, help="Pipeline run ID used for traceability and bootstrap lookup", ) - parser.add_argument( - "--version", - default="0.0.0", - help="Package or release version associated with the worker run", - ) parser.add_argument( "--artifacts-dir", default=None, @@ -269,7 +264,7 @@ def _build_publishing_inputs(*, args, run_id: str): Path(args.run_config_path) if args.run_config_path is not None else None ), run_id=run_id, - version=args.version, + version="", n_clones=args.n_clones, seed=args.seed, ) diff --git a/tests/unit/test_modal_worker_script.py b/tests/unit/test_modal_worker_script.py index 3ddcdce99..3e9d66e9b 100644 --- a/tests/unit/test_modal_worker_script.py +++ b/tests/unit/test_modal_worker_script.py @@ -78,6 +78,7 @@ def test_parse_args_accepts_worker_session_paths(): assert args.run_id == "run-123" assert args.artifacts_dir == "/tmp/artifacts/run-123" assert args.run_config_path == "/tmp/unified_run_config.json" + assert not hasattr(args, "version") def test_load_request_inputs_from_args_uses_request_payloads_when_present(): From 352cb8e0b2af5318aa7ace51201ed633ef797d25 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 21:45:13 +0200 Subject: [PATCH 10/17] Require explicit local H5 worker scope --- modal_app/local_area.py | 5 +++ modal_app/worker_script.py | 24 ++++--------- .../test_worker_script_tiny_fixture.py | 4 +++ tests/integration/support/tiny_h5.py | 2 ++ tests/integration/test_tiny_h5_pipeline.py | 3 ++ tests/unit/test_modal_worker_script.py | 34 ++++--------------- 6 files changed, 28 insertions(+), 44 deletions(-) diff --git a/modal_app/local_area.py b/modal_app/local_area.py index c4cc88047..e9c330fb4 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -618,6 +618,7 @@ def run_phase( handle = build_areas_worker.spawn( branch=branch, run_id=run_id, + scope="regional", work_items=chunk, calibration_inputs=calibration_inputs, validate=validate, @@ -713,6 +714,7 @@ def run_phase( def build_areas_worker( branch: str, run_id: str, + scope: str, work_items: List[Dict], calibration_inputs: Dict[str, object], validate: bool = True, @@ -743,6 +745,8 @@ def build_areas_worker( str(calibration_inputs["database"]), "--output-dir", str(output_dir), + "--scope", + scope, "--run-id", run_id, "--artifacts-dir", @@ -1479,6 +1483,7 @@ def coordinate_national_publish( worker_result = build_areas_worker.remote( branch=branch, run_id=run_id, + scope="national", work_items=work_items, calibration_inputs=calibration_inputs, validate=validate, diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index c94d1ffe8..1c019a9c3 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -147,6 +147,12 @@ def parse_args(argv: list[str] | None = None): parser.add_argument("--dataset-path", required=True) parser.add_argument("--db-path", required=True) parser.add_argument("--output-dir", required=True) + parser.add_argument( + "--scope", + choices=("regional", "national"), + required=True, + help="Worker bootstrap scope to use for this request batch", + ) parser.add_argument( "--run-id", default=None, @@ -225,22 +231,6 @@ def _load_request_inputs_from_args( return "work_items", tuple(json.loads(args.work_items)) -def _infer_worker_scope(request_input_mode: str, request_inputs) -> str: - """Infer which bootstrap scope matches the queued request set.""" - - if request_input_mode == "requests": - all_national = all( - getattr(request, "area_type", None) == "national" - for request in request_inputs - ) - else: - all_national = all( - isinstance(item, dict) and item.get("type") == "national" - for item in request_inputs - ) - return "national" if all_national else "regional" - - def _build_publishing_inputs(*, args, run_id: str): """Build the traceability input bundle consumed by worker setup services.""" @@ -383,7 +373,7 @@ def main(argv: list[str] | None = None): args=args, area_build_request_cls=AreaBuildRequest, ) - scope = _infer_worker_scope(request_input_mode, request_inputs) + scope = args.scope inputs = _build_publishing_inputs(args=args, run_id=run_id) session = WorkerSessionFactory().create( diff --git a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py index 545aad94f..94af49f34 100644 --- a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py +++ b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py @@ -41,6 +41,7 @@ def _run_worker( target_config: Path | None = None, validation_config: Path | None = None, run_id: str = "tiny-worker-run", + scope: str = "regional", artifacts_dir: Path | None = None, return_process: bool = False, ) -> dict: @@ -61,6 +62,8 @@ def _run_worker( str(artifacts.db_path), "--output-dir", str(output_dir), + "--scope", + scope, "--run-id", run_id, "--run-config-path", @@ -171,6 +174,7 @@ def test_worker_builds_national_h5_from_package_geography(tmp_path): artifacts=artifacts, output_dir=output_dir, use_package_geography=True, + scope="national", ) assert result["failed"] == [] diff --git a/tests/integration/support/tiny_h5.py b/tests/integration/support/tiny_h5.py index f938017c7..7a75bd967 100644 --- a/tests/integration/support/tiny_h5.py +++ b/tests/integration/support/tiny_h5.py @@ -207,6 +207,8 @@ def run_local_h5_worker( str(artifacts.db_path), "--output-dir", str(output_dir), + "--scope", + "regional", "--n-clones", str(artifacts.n_clones), "--no-validate", diff --git a/tests/integration/test_tiny_h5_pipeline.py b/tests/integration/test_tiny_h5_pipeline.py index 5fdd81024..34954b97e 100644 --- a/tests/integration/test_tiny_h5_pipeline.py +++ b/tests/integration/test_tiny_h5_pipeline.py @@ -86,6 +86,7 @@ def test_saved_geography_h5_pipeline_builds_regional_and_national_outputs(): build_result = build.remote( branch="main", run_id=run_id, + scope="regional", work_items=_work_items("district", "state", "national"), calibration_inputs=preflight_result["calibration_inputs"], validate=False, @@ -139,6 +140,7 @@ def test_package_fallback_h5_pipeline_builds_district_output(): build_result = build.remote( branch="main", run_id=run_id, + scope="regional", work_items=_work_items("district"), calibration_inputs=preflight_result["calibration_inputs"], validate=False, @@ -176,6 +178,7 @@ def test_missing_geography_h5_pipeline_fails_clearly(): build_result = build.remote( branch="main", run_id=run_id, + scope="regional", work_items=_work_items("district"), calibration_inputs=preflight_result["calibration_inputs"], validate=False, diff --git a/tests/unit/test_modal_worker_script.py b/tests/unit/test_modal_worker_script.py index 3e9d66e9b..3d1ce8f72 100644 --- a/tests/unit/test_modal_worker_script.py +++ b/tests/unit/test_modal_worker_script.py @@ -25,6 +25,8 @@ def test_parse_args_accepts_requests_json(): "/tmp/policy_data.db", "--output-dir", "/tmp/out", + "--scope", + "regional", ] ) @@ -45,6 +47,8 @@ def test_parse_args_accepts_calibration_package_path(): "/tmp/policy_data.db", "--output-dir", "/tmp/out", + "--scope", + "regional", "--calibration-package-path", "/tmp/calibration_package.pkl", ] @@ -66,6 +70,8 @@ def test_parse_args_accepts_worker_session_paths(): "/tmp/policy_data.db", "--output-dir", "/tmp/out", + "--scope", + "national", "--run-id", "run-123", "--artifacts-dir", @@ -76,6 +82,7 @@ def test_parse_args_accepts_worker_session_paths(): ) assert args.run_id == "run-123" + assert args.scope == "national" assert args.artifacts_dir == "/tmp/artifacts/run-123" assert args.run_config_path == "/tmp/unified_run_config.json" assert not hasattr(args, "version") @@ -112,33 +119,6 @@ def test_load_request_inputs_from_args_keeps_legacy_work_items_raw(): assert work_items == ({"type": "national", "id": "US"},) -def test_infer_worker_scope_uses_national_only_for_national_bootstrap(): - assert ( - worker_script._infer_worker_scope( - "requests", - (FakeRequest(area_type="national", area_id="US"),), - ) - == "national" - ) - assert ( - worker_script._infer_worker_scope( - "requests", - ( - FakeRequest(area_type="district", area_id="NC-01"), - FakeRequest(area_type="national", area_id="US"), - ), - ) - == "regional" - ) - assert ( - worker_script._infer_worker_scope( - "work_items", - ({"type": "national", "id": "US"},), - ) - == "national" - ) - - def test_work_item_key_handles_missing_fields(): assert worker_script._work_item_key({"type": "district"}) == "district:" assert worker_script._work_item_key(["not-a-dict"]) == "unknown:" From 399fe99d2a7a4015afc1196dee8bbd3f876e106d Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 12 May 2026 13:45:32 +0200 Subject: [PATCH 11/17] Align worker bootstrap integration inputs --- .../h5_worker_runtime/test_worker_script_tiny_fixture.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py index 94af49f34..ebfadf35b 100644 --- a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py +++ b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py @@ -259,6 +259,7 @@ def test_worker_consumes_scope_bootstrap_when_available(tmp_path): artifacts=artifacts, output_dir=output_dir, use_saved_geography=True, + use_package_geography=True, run_id="run-123", artifacts_dir=artifacts_dir, return_process=True, From 8f7707dafdb65daf83f3357ee019a5df3d53a9c4 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 12 May 2026 14:10:23 +0200 Subject: [PATCH 12/17] Surface successful local H5 worker logs --- modal_app/local_area.py | 4 ++- tests/unit/test_modal_local_area.py | 47 +++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/modal_app/local_area.py b/modal_app/local_area.py index e9c330fb4..40772d573 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -791,8 +791,10 @@ def build_areas_worker( env=os.environ.copy(), ) - if result.returncode != 0: + if result.stderr: print(f"Worker stderr:\n{result.stderr}", file=__import__("sys").stderr) + + if result.returncode != 0: return { "completed": [], "failed": [f"{item['type']}:{item['id']}" for item in work_items], diff --git a/tests/unit/test_modal_local_area.py b/tests/unit/test_modal_local_area.py index ab3bb5ffd..c8f28c67e 100644 --- a/tests/unit/test_modal_local_area.py +++ b/tests/unit/test_modal_local_area.py @@ -372,3 +372,50 @@ def test_build_worker_calibration_inputs_omits_missing_optional_files(tmp_path): assert "run_config" not in inputs assert "calibration_package" not in inputs + + +def test_build_areas_worker_surfaces_successful_worker_stderr( + monkeypatch, + capsys, + tmp_path, +): + local_area = load_local_area_module() + monkeypatch.setattr(local_area, "setup_gcp_credentials", lambda: None) + monkeypatch.setattr(local_area, "setup_repo", lambda branch: None) + monkeypatch.setattr(local_area, "VOLUME_MOUNT", str(tmp_path / "staging")) + monkeypatch.setattr( + local_area, + "pipeline_volume", + SimpleNamespace(reload=lambda: None), + ) + monkeypatch.setattr( + local_area, + "staging_volume", + SimpleNamespace(reload=lambda: None, commit=lambda: None), + ) + + def fake_run(cmd, **kwargs): + return SimpleNamespace( + returncode=0, + stdout='{"completed": ["district:NC-01"], "failed": [], "errors": []}', + stderr="Worker session ready: scope=regional, bootstrap=used\n", + ) + + monkeypatch.setattr(local_area.subprocess, "run", fake_run) + + result = local_area.build_areas_worker( + branch="main", + run_id="run-123", + scope="regional", + work_items=[{"type": "district", "id": "NC-01"}], + calibration_inputs={ + "weights": "/tmp/calibration_weights.npy", + "dataset": "/tmp/source.h5", + "database": "/tmp/policy_data.db", + }, + validate=False, + ) + + captured = capsys.readouterr() + assert result["completed"] == ["district:NC-01"] + assert "Worker session ready: scope=regional, bootstrap=used" in captured.err From f7779971a80cc798d9b9253fd78322a1195231b2 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 12 May 2026 14:13:11 +0200 Subject: [PATCH 13/17] Simplify local H5 validation policy --- policyengine_us_data/build_outputs/validation.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/policyengine_us_data/build_outputs/validation.py b/policyengine_us_data/build_outputs/validation.py index 95c5b229b..c0d8d61a6 100644 --- a/policyengine_us_data/build_outputs/validation.py +++ b/policyengine_us_data/build_outputs/validation.py @@ -32,19 +32,9 @@ ) @dataclass(frozen=True) class ValidationPolicy: - """Validation switches for a local H5 worker session. - - The current worker uses `enabled`; the other flags make the policy shape - explicit before later migration slices move validation behavior out of the - legacy worker subprocess. - """ + """Validation switch for a local H5 worker session.""" enabled: bool = True - fail_on_exception: bool = False - fail_on_validation_failure: bool = False - run_sanity_checks: bool = True - run_target_validation: bool = True - run_national_validation: bool = True @pipeline_node( From 79fc010680adc93295258ff8888255cd9a077fe6 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 12 May 2026 15:16:45 +0200 Subject: [PATCH 14/17] Normalize H5 worker calibration inputs --- modal_app/fixtures/h5_cases.py | 29 ++- modal_app/h5_test_harness.py | 47 ++-- modal_app/local_area.py | 68 ++---- modal_app/worker_script.py | 15 +- .../build_outputs/__init__.py | 4 +- .../build_outputs/worker_inputs.py | 231 ++++++++++++++++++ .../test_worker_script_tiny_fixture.py | 46 ++-- tests/integration/support/tiny_h5.py | 46 ++-- tests/support/modal_local_area.py | 99 ++++++++ .../unit/build_outputs/test_worker_inputs.py | 146 +++++++++++ tests/unit/test_modal_local_area.py | 15 +- 11 files changed, 602 insertions(+), 144 deletions(-) create mode 100644 policyengine_us_data/build_outputs/worker_inputs.py create mode 100644 tests/unit/build_outputs/test_worker_inputs.py diff --git a/modal_app/fixtures/h5_cases.py b/modal_app/fixtures/h5_cases.py index cf06e723f..aaae1c5fe 100644 --- a/modal_app/fixtures/h5_cases.py +++ b/modal_app/fixtures/h5_cases.py @@ -9,7 +9,11 @@ import sqlite3 from dataclasses import dataclass from pathlib import Path -from typing import Any + +from policyengine_us_data.build_outputs.worker_inputs import ( + WorkerCalibrationInputPayload, + WorkerCalibrationInputs, +) FIXTURE_DATASET_PATH = Path( "/root/policyengine-us-data/tests/integration/test_fixture_50hh.h5" @@ -28,7 +32,7 @@ class SeededCase: """Description of one tiny end-to-end H5 test case.""" name: str - calibration_inputs: dict[str, Any] + calibration_inputs: WorkerCalibrationInputPayload expected_district_name: str = DISTRICT_NAME n_clones: int = N_CLONES seed: int = SEED @@ -192,13 +196,8 @@ def seed_case( weights_path = _write_weights(artifact_dir, n_records=n_records) db_path = _write_db(artifact_dir) - calibration_inputs: dict[str, Any] = { - "weights": str(weights_path), - "dataset": str(dataset_path), - "database": str(db_path), - "n_clones": N_CLONES, - "seed": SEED, - } + geography_path = None + package_path = None if case_name == "saved_geography_success": geography_path = _write_saved_geography(artifact_dir, n_records=n_records) @@ -207,11 +206,9 @@ def seed_case( weights_path=weights_path, geography_path=geography_path, ) - calibration_inputs["geography"] = str(geography_path) elif case_name == "package_fallback_success": package_path = _write_calibration_package(artifact_dir, n_records=n_records) _write_run_config(artifact_dir, weights_path=weights_path) - calibration_inputs["calibration_package"] = str(package_path) elif case_name == "misnamed_package": _write_misnamed_package(artifact_dir, n_records=n_records) _write_run_config(artifact_dir, weights_path=weights_path) @@ -220,5 +217,13 @@ def seed_case( return SeededCase( name=case_name, - calibration_inputs=calibration_inputs, + calibration_inputs=WorkerCalibrationInputs( + weights_path=weights_path, + dataset_path=dataset_path, + database_path=db_path, + geography_path=geography_path, + calibration_package_path=package_path, + n_clones=N_CLONES, + seed=SEED, + ).to_wire_dict(), ) diff --git a/modal_app/h5_test_harness.py b/modal_app/h5_test_harness.py index cdcf27335..225dc2166 100644 --- a/modal_app/h5_test_harness.py +++ b/modal_app/h5_test_harness.py @@ -19,6 +19,10 @@ from modal_app.images import cpu_image as image # noqa: E402 from modal_app.local_area import VOLUME_MOUNT, pipeline_volume, staging_volume # noqa: E402 +from policyengine_us_data.build_outputs.worker_inputs import ( # noqa: E402 + WorkerCalibrationInputPayload, + WorkerCalibrationInputs, +) app = modal.App( @@ -70,19 +74,16 @@ def _calibration_inputs( calibration_package_path: Path | None = None, n_clones: int = 1, seed: int = 42, -) -> dict: - inputs = { - "weights": str(weights_path), - "dataset": str(dataset_path), - "database": str(db_path), - "n_clones": n_clones, - "seed": seed, - } - if geography_path is not None: - inputs["geography"] = str(geography_path) - if calibration_package_path is not None: - inputs["calibration_package"] = str(calibration_package_path) - return inputs +) -> WorkerCalibrationInputPayload: + return WorkerCalibrationInputs( + weights_path=weights_path, + dataset_path=dataset_path, + database_path=db_path, + geography_path=geography_path, + calibration_package_path=calibration_package_path, + n_clones=n_clones, + seed=seed, + ).to_wire_dict() @app.function( @@ -290,17 +291,15 @@ def preflight_h5_case(run_id: str, *, n_clones: int = 1) -> dict: calibration_package_path=package_path if package_path.exists() else None, blocks_path=artifact_dir / "stacked_blocks.npy", ) - calibration_inputs = { - "weights": str(weights_path), - "dataset": str(dataset_path), - "database": str(db_path), - "n_clones": n_clones, - "seed": SEED, - } - if geography_path.exists(): - calibration_inputs["geography"] = str(geography_path) - if package_path.exists(): - calibration_inputs["calibration_package"] = str(package_path) + calibration_inputs = WorkerCalibrationInputs.from_artifact_paths( + weights_path=weights_path, + dataset_path=dataset_path, + database_path=db_path, + geography_path=geography_path, + calibration_package_path=package_path, + n_clones=n_clones, + seed=SEED, + ).to_wire_dict() return { "fingerprint": fingerprint, "geography_source": resolved.kind if resolved is not None else None, diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 40772d573..a4ece84b7 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -17,7 +17,7 @@ import sys import traceback from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Mapping import modal @@ -39,6 +39,9 @@ from policyengine_us_data.build_outputs.partitioning import ( # noqa: E402 partition_weighted_work_items, ) +from policyengine_us_data.build_outputs.worker_inputs import ( # noqa: E402 + WorkerCalibrationInputs, +) from policyengine_us_data.pipeline_metadata import pipeline_node # noqa: E402 from policyengine_us_data.pipeline_schema import PipelineNode # noqa: E402 from policyengine_us_data.utils.run_context import resolve_run_id # noqa: E402 @@ -499,22 +502,19 @@ def _build_worker_calibration_inputs( seed: int, run_config_path: Path | None = None, calibration_package_path: Path | None = None, -) -> Dict[str, object]: - """Build the calibration input payload passed to H5 worker subprocesses.""" - - calibration_inputs: Dict[str, object] = { - "weights": str(weights_path), - "geography": str(geography_path), - "dataset": str(dataset_path), - "database": str(db_path), - "n_clones": n_clones, - "seed": seed, - } - if run_config_path is not None and run_config_path.exists(): - calibration_inputs["run_config"] = str(run_config_path) - if calibration_package_path is not None and calibration_package_path.exists(): - calibration_inputs["calibration_package"] = str(calibration_package_path) - return calibration_inputs +) -> WorkerCalibrationInputs: + """Build the normalized H5 worker input payload.""" + + return WorkerCalibrationInputs.from_artifact_paths( + weights_path=weights_path, + geography_path=geography_path, + dataset_path=dataset_path, + database_path=db_path, + n_clones=n_clones, + seed=seed, + run_config_path=run_config_path, + calibration_package_path=calibration_package_path, + ) @pipeline_node( @@ -589,7 +589,7 @@ def run_phase( completed: set, branch: str, run_id: str, - calibration_inputs: Dict[str, str], + calibration_inputs: WorkerCalibrationInputs | Mapping[str, object], run_dir: Path, validate: bool = True, ) -> tuple: @@ -603,6 +603,9 @@ def run_phase( """ work_chunks = partition_work(work_items, num_workers, completed) total_remaining = sum(len(c) for c in work_chunks) + worker_input_payload = WorkerCalibrationInputs.from_wire_dict( + calibration_inputs + ).to_wire_dict() print(f"\n--- Phase: {phase_name} ---") print(f"Remaining work: {total_remaining} items across {len(work_chunks)} workers") @@ -620,7 +623,7 @@ def run_phase( run_id=run_id, scope="regional", work_items=chunk, - calibration_inputs=calibration_inputs, + calibration_inputs=worker_input_payload, validate=validate, ) print(f" → fc: {handle.object_id}") @@ -716,7 +719,7 @@ def build_areas_worker( run_id: str, scope: str, work_items: List[Dict], - calibration_inputs: Dict[str, object], + calibration_inputs: WorkerCalibrationInputs | Mapping[str, object], validate: bool = True, ) -> Dict: """ @@ -732,17 +735,13 @@ def build_areas_worker( output_dir.mkdir(parents=True, exist_ok=True) work_items_json = json.dumps(work_items) + worker_inputs = WorkerCalibrationInputs.from_wire_dict(calibration_inputs) worker_cmd = [ *_python_cmd("-m", "modal_app.worker_script"), "--work-items", work_items_json, - "--weights-path", - str(calibration_inputs["weights"]), - "--dataset-path", - str(calibration_inputs["dataset"]), - "--db-path", - str(calibration_inputs["database"]), + *worker_inputs.to_worker_cli_args(), "--output-dir", str(output_dir), "--scope", @@ -752,21 +751,6 @@ def build_areas_worker( "--artifacts-dir", str(Path("/pipeline/artifacts") / run_id), ] - if "geography" in calibration_inputs: - worker_cmd.extend(["--geography-path", str(calibration_inputs["geography"])]) - if "calibration_package" in calibration_inputs: - worker_cmd.extend( - [ - "--calibration-package-path", - str(calibration_inputs["calibration_package"]), - ] - ) - if "n_clones" in calibration_inputs: - worker_cmd.extend(["--n-clones", str(calibration_inputs["n_clones"])]) - if "seed" in calibration_inputs: - worker_cmd.extend(["--seed", str(calibration_inputs["seed"])]) - if "run_config" in calibration_inputs: - worker_cmd.extend(["--run-config-path", str(calibration_inputs["run_config"])]) repo_root = Path("/root/policyengine-us-data") cal_dir = repo_root / "policyengine_us_data" / "calibration" worker_cmd.extend( @@ -1487,7 +1471,7 @@ def coordinate_national_publish( run_id=run_id, scope="national", work_items=work_items, - calibration_inputs=calibration_inputs, + calibration_inputs=calibration_inputs.to_wire_dict(), validate=validate, ) diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index 1c019a9c3..d8e4ad21b 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -234,15 +234,15 @@ def _load_request_inputs_from_args( def _build_publishing_inputs(*, args, run_id: str): """Build the traceability input bundle consumed by worker setup services.""" - from policyengine_us_data.build_outputs.fingerprinting import ( - PublishingInputBundle, + from policyengine_us_data.build_outputs.worker_inputs import ( + WorkerCalibrationInputs, ) - return PublishingInputBundle( + worker_inputs = WorkerCalibrationInputs( weights_path=Path(args.weights_path), - source_dataset_path=Path(args.dataset_path), - target_db_path=Path(args.db_path) if args.db_path else None, - exact_geography_path=( + dataset_path=Path(args.dataset_path), + database_path=Path(args.db_path), + geography_path=( Path(args.geography_path) if args.geography_path is not None else None ), calibration_package_path=( @@ -253,11 +253,10 @@ def _build_publishing_inputs(*, args, run_id: str): run_config_path=( Path(args.run_config_path) if args.run_config_path is not None else None ), - run_id=run_id, - version="", n_clones=args.n_clones, seed=args.seed, ) + return worker_inputs.to_publishing_input_bundle(run_id=run_id) def _build_kwargs_from_request(request) -> dict[str, Any]: diff --git a/policyengine_us_data/build_outputs/__init__.py b/policyengine_us_data/build_outputs/__init__.py index e0eda9a09..6f2ec1c4c 100644 --- a/policyengine_us_data/build_outputs/__init__.py +++ b/policyengine_us_data/build_outputs/__init__.py @@ -4,6 +4,6 @@ seams rather than speculative placeholders. The current early slices support H5 output request construction, exact calibration geography loading, fingerprinting, clone-weight shape contracts, worker partitioning, source -dataset snapshot contracts, introduced worker-bootstrap artifacts, and -worker-scoped session and validation context setup. +dataset snapshot contracts, worker input normalization, worker-bootstrap +artifacts, and worker-scoped session and validation context setup. """ diff --git a/policyengine_us_data/build_outputs/worker_inputs.py b/policyengine_us_data/build_outputs/worker_inputs.py new file mode 100644 index 000000000..c0896597b --- /dev/null +++ b/policyengine_us_data/build_outputs/worker_inputs.py @@ -0,0 +1,231 @@ +"""Normalized input payloads for local H5 worker execution.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Mapping, TypeAlias + +from policyengine_us_data.pipeline_metadata import pipeline_node + +from .fingerprinting import PublishingInputBundle + +WorkerCalibrationInputValue: TypeAlias = str | int +WorkerCalibrationInputPayload: TypeAlias = dict[ + str, + WorkerCalibrationInputValue, +] + + +def _coerce_path(value: object, *, field_name: str) -> Path: + """Return a path from a wire value or raise a clear contract error.""" + + if isinstance(value, Path): + return value + if isinstance(value, str): + return Path(value) + raise TypeError(f"{field_name} must be a path string, got {type(value).__name__}") + + +def _coerce_optional_path(value: object, *, field_name: str) -> Path | None: + """Return an optional path from a wire value.""" + + if value is None: + return None + return _coerce_path(value, field_name=field_name) + + +def _coerce_int(value: object, *, field_name: str) -> int: + """Return an integer from a wire value or raise a clear contract error.""" + + if isinstance(value, bool): + raise TypeError(f"{field_name} must be an int, got bool") + if isinstance(value, int): + return value + if isinstance(value, str): + return int(value) + raise TypeError(f"{field_name} must be an int, got {type(value).__name__}") + + +@pipeline_node( + id="local_h5_worker_calibration_inputs", + label="WorkerCalibrationInputs", + node_type="library", + description="Normalized worker-execution input payload for local H5 builds.", + source_file="policyengine_us_data/build_outputs/worker_inputs.py", + status="current", + stability="moving", + pathways=["local_h5"], + validation_commands=[ + "uv run pytest tests/unit/build_outputs/test_worker_inputs.py" + ], +) +@dataclass(frozen=True) +class WorkerCalibrationInputs: + """Input artifact paths and runtime settings for one H5 worker batch. + + This is the typed library contract. Modal entrypoints may still exchange the + `to_wire_dict()` representation because it is easier to serialize and + inspect, but all producers and consumers should normalize through this + class before using field values. + """ + + weights_path: Path + dataset_path: Path + database_path: Path + geography_path: Path | None = None + calibration_package_path: Path | None = None + run_config_path: Path | None = None + n_clones: int = 430 + seed: int = 42 + + @classmethod + def from_artifact_paths( + cls, + *, + weights_path: Path, + dataset_path: Path, + database_path: Path, + geography_path: Path | None = None, + calibration_package_path: Path | None = None, + run_config_path: Path | None = None, + n_clones: int = 430, + seed: int = 42, + require_optional_paths_exist: bool = True, + ) -> "WorkerCalibrationInputs": + """Build worker inputs from coordinator artifact paths. + + Optional paths are included only when present by default, matching the + previous coordinator behavior for run configs and calibration packages. + """ + + if require_optional_paths_exist: + geography_path = geography_path if _exists(geography_path) else None + calibration_package_path = ( + calibration_package_path if _exists(calibration_package_path) else None + ) + run_config_path = run_config_path if _exists(run_config_path) else None + + return cls( + weights_path=weights_path, + dataset_path=dataset_path, + database_path=database_path, + geography_path=geography_path, + calibration_package_path=calibration_package_path, + run_config_path=run_config_path, + n_clones=n_clones, + seed=seed, + ) + + @classmethod + def from_wire_dict( + cls, + payload: Mapping[str, object] | "WorkerCalibrationInputs", + ) -> "WorkerCalibrationInputs": + """Normalize a Modal-safe worker input payload.""" + + if isinstance(payload, cls): + return payload + + missing = [ + key for key in ("weights", "dataset", "database") if key not in payload + ] + if missing: + raise KeyError( + "Missing required worker calibration input(s): " + ", ".join(missing) + ) + + return cls( + weights_path=_coerce_path(payload["weights"], field_name="weights"), + dataset_path=_coerce_path(payload["dataset"], field_name="dataset"), + database_path=_coerce_path(payload["database"], field_name="database"), + geography_path=_coerce_optional_path( + payload.get("geography"), + field_name="geography", + ), + calibration_package_path=_coerce_optional_path( + payload.get("calibration_package"), + field_name="calibration_package", + ), + run_config_path=_coerce_optional_path( + payload.get("run_config"), + field_name="run_config", + ), + n_clones=_coerce_int(payload.get("n_clones", 430), field_name="n_clones"), + seed=_coerce_int(payload.get("seed", 42), field_name="seed"), + ) + + def to_wire_dict(self) -> WorkerCalibrationInputPayload: + """Return the Modal-safe payload used by remote worker entrypoints.""" + + payload: WorkerCalibrationInputPayload = { + "weights": str(self.weights_path), + "dataset": str(self.dataset_path), + "database": str(self.database_path), + "n_clones": self.n_clones, + "seed": self.seed, + } + if self.geography_path is not None: + payload["geography"] = str(self.geography_path) + if self.calibration_package_path is not None: + payload["calibration_package"] = str(self.calibration_package_path) + if self.run_config_path is not None: + payload["run_config"] = str(self.run_config_path) + return payload + + def to_worker_cli_args(self) -> list[str]: + """Return worker_script CLI arguments for these inputs.""" + + args = [ + "--weights-path", + str(self.weights_path), + "--dataset-path", + str(self.dataset_path), + "--db-path", + str(self.database_path), + "--n-clones", + str(self.n_clones), + "--seed", + str(self.seed), + ] + if self.geography_path is not None: + args.extend(["--geography-path", str(self.geography_path)]) + if self.calibration_package_path is not None: + args.extend( + [ + "--calibration-package-path", + str(self.calibration_package_path), + ] + ) + if self.run_config_path is not None: + args.extend(["--run-config-path", str(self.run_config_path)]) + return args + + def to_publishing_input_bundle( + self, + *, + run_id: str, + version: str = "", + legacy_blocks_path: Path | None = None, + ) -> PublishingInputBundle: + """Return the traceability/fingerprinting input bundle.""" + + return PublishingInputBundle( + weights_path=self.weights_path, + source_dataset_path=self.dataset_path, + target_db_path=self.database_path, + exact_geography_path=self.geography_path, + calibration_package_path=self.calibration_package_path, + run_config_path=self.run_config_path, + run_id=run_id, + version=version, + n_clones=self.n_clones, + seed=self.seed, + legacy_blocks_path=legacy_blocks_path, + ) + + +def _exists(path: Path | None) -> bool: + """Return whether an optional artifact path exists.""" + + return path is not None and path.exists() diff --git a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py index ebfadf35b..4834d1068 100644 --- a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py +++ b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py @@ -9,12 +9,12 @@ import pytest from policyengine_us_data.build_outputs.bootstrap import WorkerBootstrapBuilder -from policyengine_us_data.build_outputs.fingerprinting import PublishingInputBundle from policyengine_us_data.build_outputs.source_dataset import ( DEFAULT_SUBENTITIES, PolicyEngineDatasetReader, ) from policyengine_us_data.build_outputs.weights import CloneWeightMatrix +from policyengine_us_data.build_outputs.worker_inputs import WorkerCalibrationInputs from tests.integration.build_outputs.fixtures import ( build_request, seed_local_h5_artifacts, @@ -48,28 +48,31 @@ def _run_worker( _require_worker_dependencies() if not isinstance(requests, (list, tuple)): requests = (requests,) + worker_inputs = WorkerCalibrationInputs( + weights_path=artifacts.weights_path, + dataset_path=artifacts.dataset_path, + database_path=artifacts.db_path, + geography_path=artifacts.geography_path if use_saved_geography else None, + calibration_package_path=( + artifacts.calibration_package_path if use_package_geography else None + ), + run_config_path=artifacts.run_config_path, + n_clones=artifacts.n_clones, + seed=42, + ) cmd = [ sys.executable, "-m", "modal_app.worker_script", "--requests-json", json.dumps([request.to_dict() for request in requests]), - "--weights-path", - str(artifacts.weights_path), - "--dataset-path", - str(artifacts.dataset_path), - "--db-path", - str(artifacts.db_path), + *worker_inputs.to_worker_cli_args(), "--output-dir", str(output_dir), "--scope", scope, "--run-id", run_id, - "--run-config-path", - str(artifacts.run_config_path), - "--n-clones", - str(artifacts.n_clones), ] if artifacts_dir is not None: cmd.extend(["--artifacts-dir", str(artifacts_dir)]) @@ -79,15 +82,6 @@ def _run_worker( cmd.extend(["--target-config", str(target_config)]) if validation_config is not None: cmd.extend(["--validation-config", str(validation_config)]) - if use_saved_geography: - cmd.extend(["--geography-path", str(artifacts.geography_path)]) - if use_package_geography: - cmd.extend( - [ - "--calibration-package-path", - str(artifacts.calibration_package_path), - ] - ) result = subprocess.run( cmd, @@ -236,18 +230,16 @@ def test_worker_consumes_scope_bootstrap_when_available(tmp_path): request = build_request("district", geography=artifacts.geography) output_dir = tmp_path / "bootstrap-out" artifacts_dir = tmp_path / "pipeline-artifacts" / "run-123" - inputs = PublishingInputBundle( + inputs = WorkerCalibrationInputs( weights_path=artifacts.weights_path, - source_dataset_path=artifacts.dataset_path, - target_db_path=artifacts.db_path, - exact_geography_path=artifacts.geography_path, + dataset_path=artifacts.dataset_path, + database_path=artifacts.db_path, + geography_path=artifacts.geography_path, calibration_package_path=artifacts.calibration_package_path, run_config_path=artifacts.run_config_path, - run_id="run-123", - version="0.0.0", n_clones=artifacts.n_clones, seed=42, - ) + ).to_publishing_input_bundle(run_id="run-123", version="0.0.0") WorkerBootstrapBuilder().build( inputs=inputs, scope="regional", diff --git a/tests/integration/support/tiny_h5.py b/tests/integration/support/tiny_h5.py index 7a75bd967..9ae3879fe 100644 --- a/tests/integration/support/tiny_h5.py +++ b/tests/integration/support/tiny_h5.py @@ -23,6 +23,9 @@ AreaBuildRequest, AreaFilter, ) +from policyengine_us_data.build_outputs.worker_inputs import ( + WorkerCalibrationInputs, +) from tests.integration.support.pipeline_workspace import TinyPipelineWorkspace from tests.integration.support.tiny_pipeline import TinyPipelineArtifacts @@ -167,20 +170,22 @@ def build_publishing_input_bundle( ) -> PublishingInputBundle: """Build the same traceability input shape used by local H5 publication.""" - return PublishingInputBundle( + worker_inputs = WorkerCalibrationInputs( weights_path=artifacts.weights_path, - source_dataset_path=artifacts.dataset_path, - target_db_path=artifacts.db_path, - exact_geography_path=artifacts.geography_path, + dataset_path=artifacts.dataset_path, + database_path=artifacts.db_path, + geography_path=artifacts.geography_path, calibration_package_path=( artifacts.calibration_package_path if scope == "regional" else None ), run_config_path=artifacts.run_config_path, - run_id=run_id, - version=VERSION, n_clones=artifacts.n_clones, seed=SEED, ) + return worker_inputs.to_publishing_input_bundle( + run_id=run_id, + version=VERSION, + ) def run_local_h5_worker( @@ -193,35 +198,30 @@ def run_local_h5_worker( ) -> dict: """Run the real local H5 worker subprocess for tiny fixture requests.""" + worker_inputs = WorkerCalibrationInputs( + weights_path=artifacts.weights_path, + dataset_path=artifacts.dataset_path, + database_path=artifacts.db_path, + geography_path=artifacts.geography_path if use_saved_geography else None, + calibration_package_path=( + artifacts.calibration_package_path if use_package_geography else None + ), + n_clones=artifacts.n_clones, + seed=SEED, + ) cmd = [ sys.executable, "-m", "modal_app.worker_script", "--requests-json", json.dumps([request.to_dict() for request in requests]), - "--weights-path", - str(artifacts.weights_path), - "--dataset-path", - str(artifacts.dataset_path), - "--db-path", - str(artifacts.db_path), + *worker_inputs.to_worker_cli_args(), "--output-dir", str(output_dir), "--scope", "regional", - "--n-clones", - str(artifacts.n_clones), "--no-validate", ] - if use_saved_geography: - cmd.extend(["--geography-path", str(artifacts.geography_path)]) - if use_package_geography: - cmd.extend( - [ - "--calibration-package-path", - str(artifacts.calibration_package_path), - ] - ) result = subprocess.run( cmd, diff --git a/tests/support/modal_local_area.py b/tests/support/modal_local_area.py index 9e32d9e96..0e5f2a426 100644 --- a/tests/support/modal_local_area.py +++ b/tests/support/modal_local_area.py @@ -83,6 +83,9 @@ def decorator(func): fake_fingerprinting = ModuleType( "policyengine_us_data.build_outputs.fingerprinting" ) + fake_worker_inputs = ModuleType( + "policyengine_us_data.build_outputs.worker_inputs" + ) fake_policyengine.__path__ = [] fake_calibration.__path__ = [] fake_build_outputs.__path__ = [] @@ -114,6 +117,99 @@ def build(self, *args, **kwargs): fake_bootstrap.WorkerBootstrapBuilder = _FakeWorkerBootstrapBuilder fake_fingerprinting.PublishingInputBundle = object + class _FakeWorkerCalibrationInputs: + def __init__( + self, + *, + weights_path, + dataset_path, + database_path, + geography_path=None, + calibration_package_path=None, + run_config_path=None, + n_clones=430, + seed=42, + ): + self.weights_path = weights_path + self.dataset_path = dataset_path + self.database_path = database_path + self.geography_path = geography_path + self.calibration_package_path = calibration_package_path + self.run_config_path = run_config_path + self.n_clones = n_clones + self.seed = seed + + @classmethod + def from_artifact_paths(cls, **kwargs): + for key in ( + "geography_path", + "calibration_package_path", + "run_config_path", + ): + path = kwargs.get(key) + if path is not None and not path.exists(): + kwargs[key] = None + return cls(**kwargs) + + @classmethod + def from_wire_dict(cls, payload): + if isinstance(payload, cls): + return payload + return cls( + weights_path=payload["weights"], + dataset_path=payload["dataset"], + database_path=payload["database"], + geography_path=payload.get("geography"), + calibration_package_path=payload.get("calibration_package"), + run_config_path=payload.get("run_config"), + n_clones=payload.get("n_clones", 430), + seed=payload.get("seed", 42), + ) + + def to_worker_cli_args(self): + args = [ + "--weights-path", + str(self.weights_path), + "--dataset-path", + str(self.dataset_path), + "--db-path", + str(self.database_path), + "--n-clones", + str(self.n_clones), + "--seed", + str(self.seed), + ] + if self.geography_path is not None: + args.extend(["--geography-path", str(self.geography_path)]) + if self.calibration_package_path is not None: + args.extend( + [ + "--calibration-package-path", + str(self.calibration_package_path), + ] + ) + if self.run_config_path is not None: + args.extend(["--run-config-path", str(self.run_config_path)]) + return args + + def to_wire_dict(self): + payload = { + "weights": str(self.weights_path), + "dataset": str(self.dataset_path), + "database": str(self.database_path), + "n_clones": self.n_clones, + "seed": self.seed, + } + if self.geography_path is not None: + payload["geography"] = str(self.geography_path) + if self.calibration_package_path is not None: + payload["calibration_package"] = str(self.calibration_package_path) + if self.run_config_path is not None: + payload["run_config"] = str(self.run_config_path) + return payload + + fake_worker_inputs.WorkerCalibrationInputs = _FakeWorkerCalibrationInputs + class _FakeFingerprintingService: def build_traceability(self, *args, **kwargs): return object() @@ -136,6 +232,9 @@ def compute_scope_fingerprint(self, *args, **kwargs): fake_fingerprinting ), "policyengine_us_data.build_outputs.partitioning": (fake_partitioning), + "policyengine_us_data.build_outputs.worker_inputs": ( + fake_worker_inputs + ), } ) diff --git a/tests/unit/build_outputs/test_worker_inputs.py b/tests/unit/build_outputs/test_worker_inputs.py new file mode 100644 index 000000000..3bdc6e9fd --- /dev/null +++ b/tests/unit/build_outputs/test_worker_inputs.py @@ -0,0 +1,146 @@ +from pathlib import Path + +import pytest + +from policyengine_us_data.build_outputs.worker_inputs import ( + WorkerCalibrationInputs, +) + + +def test_worker_calibration_inputs_round_trip_wire_payload(): + inputs = WorkerCalibrationInputs( + weights_path=Path("/tmp/calibration_weights.npy"), + dataset_path=Path("/tmp/source.h5"), + database_path=Path("/tmp/policy_data.db"), + geography_path=Path("/tmp/geography_assignment.npz"), + calibration_package_path=Path("/tmp/calibration_package.pkl"), + run_config_path=Path("/tmp/unified_run_config.json"), + n_clones=4, + seed=123, + ) + + payload = inputs.to_wire_dict() + normalized = WorkerCalibrationInputs.from_wire_dict(payload) + + assert normalized == inputs + assert payload == { + "weights": "/tmp/calibration_weights.npy", + "dataset": "/tmp/source.h5", + "database": "/tmp/policy_data.db", + "geography": "/tmp/geography_assignment.npz", + "calibration_package": "/tmp/calibration_package.pkl", + "run_config": "/tmp/unified_run_config.json", + "n_clones": 4, + "seed": 123, + } + + +def test_worker_calibration_inputs_defaults_legacy_modal_payload_values(): + inputs = WorkerCalibrationInputs.from_wire_dict( + { + "weights": "/tmp/calibration_weights.npy", + "dataset": "/tmp/source.h5", + "database": "/tmp/policy_data.db", + } + ) + + assert inputs.n_clones == 430 + assert inputs.seed == 42 + + +def test_worker_calibration_inputs_build_worker_cli_args(): + inputs = WorkerCalibrationInputs( + weights_path=Path("/tmp/calibration_weights.npy"), + dataset_path=Path("/tmp/source.h5"), + database_path=Path("/tmp/policy_data.db"), + geography_path=Path("/tmp/geography_assignment.npz"), + calibration_package_path=Path("/tmp/calibration_package.pkl"), + run_config_path=Path("/tmp/unified_run_config.json"), + n_clones=4, + seed=123, + ) + + assert inputs.to_worker_cli_args() == [ + "--weights-path", + "/tmp/calibration_weights.npy", + "--dataset-path", + "/tmp/source.h5", + "--db-path", + "/tmp/policy_data.db", + "--n-clones", + "4", + "--seed", + "123", + "--geography-path", + "/tmp/geography_assignment.npz", + "--calibration-package-path", + "/tmp/calibration_package.pkl", + "--run-config-path", + "/tmp/unified_run_config.json", + ] + + +def test_worker_calibration_inputs_build_publishing_input_bundle(): + inputs = WorkerCalibrationInputs( + weights_path=Path("/tmp/calibration_weights.npy"), + dataset_path=Path("/tmp/source.h5"), + database_path=Path("/tmp/policy_data.db"), + geography_path=Path("/tmp/geography_assignment.npz"), + calibration_package_path=Path("/tmp/calibration_package.pkl"), + run_config_path=Path("/tmp/unified_run_config.json"), + n_clones=4, + seed=123, + ) + + bundle = inputs.to_publishing_input_bundle( + run_id="run-123", + version="1.2.3", + legacy_blocks_path=Path("/tmp/stacked_blocks.npy"), + ) + + assert bundle.weights_path == inputs.weights_path + assert bundle.source_dataset_path == inputs.dataset_path + assert bundle.target_db_path == inputs.database_path + assert bundle.exact_geography_path == inputs.geography_path + assert bundle.calibration_package_path == inputs.calibration_package_path + assert bundle.run_config_path == inputs.run_config_path + assert bundle.run_id == "run-123" + assert bundle.version == "1.2.3" + assert bundle.n_clones == 4 + assert bundle.seed == 123 + assert bundle.legacy_blocks_path == Path("/tmp/stacked_blocks.npy") + + +def test_worker_calibration_inputs_omit_missing_optional_artifact_paths(tmp_path): + run_config_path = tmp_path / "unified_run_config.json" + run_config_path.write_text("{}") + + inputs = WorkerCalibrationInputs.from_artifact_paths( + weights_path=tmp_path / "calibration_weights.npy", + dataset_path=tmp_path / "source.h5", + database_path=tmp_path / "policy_data.db", + geography_path=tmp_path / "missing_geography_assignment.npz", + calibration_package_path=tmp_path / "missing_calibration_package.pkl", + run_config_path=run_config_path, + n_clones=4, + seed=123, + ) + + payload = inputs.to_wire_dict() + + assert inputs.geography_path is None + assert inputs.calibration_package_path is None + assert inputs.run_config_path == run_config_path + assert "geography" not in payload + assert "calibration_package" not in payload + assert payload["run_config"] == str(run_config_path) + + +def test_worker_calibration_inputs_reject_missing_required_paths(): + with pytest.raises(KeyError, match="weights"): + WorkerCalibrationInputs.from_wire_dict( + { + "dataset": "/tmp/source.h5", + "database": "/tmp/policy_data.db", + } + ) diff --git a/tests/unit/test_modal_local_area.py b/tests/unit/test_modal_local_area.py index c8f28c67e..b9351e7f1 100644 --- a/tests/unit/test_modal_local_area.py +++ b/tests/unit/test_modal_local_area.py @@ -350,10 +350,11 @@ def test_build_worker_calibration_inputs_includes_existing_run_config_and_packag calibration_package_path=package_path, ) - assert inputs["run_config"] == str(run_config_path) - assert inputs["calibration_package"] == str(package_path) - assert inputs["n_clones"] == 430 - assert inputs["seed"] == 42 + assert inputs.run_config_path == run_config_path + assert inputs.calibration_package_path == package_path + assert inputs.n_clones == 430 + assert inputs.seed == 42 + assert inputs.to_wire_dict()["run_config"] == str(run_config_path) def test_build_worker_calibration_inputs_omits_missing_optional_files(tmp_path): @@ -370,8 +371,10 @@ def test_build_worker_calibration_inputs_omits_missing_optional_files(tmp_path): calibration_package_path=tmp_path / "missing_package.pkl", ) - assert "run_config" not in inputs - assert "calibration_package" not in inputs + assert inputs.run_config_path is None + assert inputs.calibration_package_path is None + assert "run_config" not in inputs.to_wire_dict() + assert "calibration_package" not in inputs.to_wire_dict() def test_build_areas_worker_surfaces_successful_worker_stderr( From 70f359080b7f36f785b303c45b056e6f5f915939 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 12 May 2026 16:03:30 +0200 Subject: [PATCH 15/17] Use cheap H5 worker bootstrap validation --- modal_app/local_area.py | 7 ++ modal_app/worker_script.py | 6 ++ .../build_outputs/worker_session.py | 91 ++++--------------- .../test_worker_script_tiny_fixture.py | 5 + .../unit/build_outputs/test_worker_session.py | 57 ++++++++---- tests/unit/test_modal_local_area.py | 5 + tests/unit/test_modal_worker_script.py | 3 + 7 files changed, 84 insertions(+), 90 deletions(-) diff --git a/modal_app/local_area.py b/modal_app/local_area.py index a4ece84b7..89ce90b09 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -592,6 +592,7 @@ def run_phase( calibration_inputs: WorkerCalibrationInputs | Mapping[str, object], run_dir: Path, validate: bool = True, + scope_fingerprint: str | None = None, ) -> tuple: """Run a single build phase, spawning workers and collecting results. @@ -625,6 +626,7 @@ def run_phase( work_items=chunk, calibration_inputs=worker_input_payload, validate=validate, + scope_fingerprint=scope_fingerprint, ) print(f" → fc: {handle.object_id}") handles.append(handle) @@ -721,6 +723,7 @@ def build_areas_worker( work_items: List[Dict], calibration_inputs: WorkerCalibrationInputs | Mapping[str, object], validate: bool = True, + scope_fingerprint: str | None = None, ) -> Dict: """ Worker function that builds a subset of H5 files. @@ -751,6 +754,8 @@ def build_areas_worker( "--artifacts-dir", str(Path("/pipeline/artifacts") / run_id), ] + if scope_fingerprint: + worker_cmd.extend(["--scope-fingerprint", scope_fingerprint]) repo_root = Path("/root/policyengine-us-data") cal_dir = repo_root / "policyengine_us_data" / "calibration" worker_cmd.extend( @@ -1239,6 +1244,7 @@ def coordinate_publish( calibration_inputs=calibration_inputs, run_dir=run_dir, validate=validate, + scope_fingerprint=fingerprint, ) accumulated_errors = [] @@ -1473,6 +1479,7 @@ def coordinate_national_publish( work_items=work_items, calibration_inputs=calibration_inputs.to_wire_dict(), validate=validate, + scope_fingerprint=fingerprint, ) print( diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index d8e4ad21b..59350a1f3 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -168,6 +168,11 @@ def parse_args(argv: list[str] | None = None): default=None, help="Optional unified run configuration JSON used for traceability", ) + parser.add_argument( + "--scope-fingerprint", + default=None, + help="Coordinator-resolved scope fingerprint expected by bootstrap artifacts", + ) parser.add_argument( "--geography-path", default=None, @@ -385,6 +390,7 @@ def main(argv: list[str] | None = None): Path(args.validation_config) if args.validation_config else None ), artifacts_dir=Path(args.artifacts_dir) if args.artifacts_dir else None, + expected_scope_fingerprint=args.scope_fingerprint, ) weights = session.weights.values n_records = session.weights.n_records diff --git a/policyengine_us_data/build_outputs/worker_session.py b/policyengine_us_data/build_outputs/worker_session.py index 355833293..855fe62e9 100644 --- a/policyengine_us_data/build_outputs/worker_session.py +++ b/policyengine_us_data/build_outputs/worker_session.py @@ -16,7 +16,7 @@ WorkerBootstrapStore, load_entity_graph, ) -from .fingerprinting import FingerprintingService, PublishingInputBundle +from .fingerprinting import PublishingInputBundle from .geography_loader import CalibrationGeographyLoader from .source_dataset import PolicyEngineDatasetReader, SourceDatasetSnapshot from .validation import AreaValidationService, ValidationContext, ValidationPolicy @@ -97,7 +97,6 @@ def __init__( dataset_reader: PolicyEngineDatasetReader | None = None, geography_loader: CalibrationGeographyLoader | None = None, validation_service: AreaValidationService | None = None, - fingerprinting_service: FingerprintingService | None = None, bootstrap_store: WorkerBootstrapStore | None = None, ) -> None: """Create a session factory with injectable seams for tests.""" @@ -105,9 +104,6 @@ def __init__( self._dataset_reader = dataset_reader or PolicyEngineDatasetReader() self._geography_loader = geography_loader or CalibrationGeographyLoader() self._validation_service = validation_service or AreaValidationService() - self._fingerprinting_service = fingerprinting_service or FingerprintingService( - geography_loader=self._geography_loader - ) self._bootstrap_store = bootstrap_store def create( @@ -120,6 +116,7 @@ def create( target_config_path: Path | None = None, validation_config_path: Path | None = None, artifacts_dir: Path | None = None, + expected_scope_fingerprint: str | None = None, ) -> WorkerSession: """Create a worker session for one local H5 scope. @@ -141,6 +138,7 @@ def create( bundle=bundle, inputs=inputs, scope=scope, + expected_scope_fingerprint=expected_scope_fingerprint, ) if bootstrap_error is not None: bundle = None @@ -216,12 +214,14 @@ def _validate_bootstrap_bundle( bundle: WorkerBootstrapBundle, inputs: PublishingInputBundle, scope: BootstrapScope, + expected_scope_fingerprint: str | None, ) -> Exception | None: try: self._raise_for_bootstrap_mismatch( bundle=bundle, inputs=inputs, scope=scope, + expected_scope_fingerprint=expected_scope_fingerprint, ) except Exception as exc: return exc @@ -233,6 +233,7 @@ def _raise_for_bootstrap_mismatch( bundle: WorkerBootstrapBundle, inputs: PublishingInputBundle, scope: BootstrapScope, + expected_scope_fingerprint: str | None, ) -> None: if bundle.run_id != inputs.run_id: raise ValueError( @@ -245,38 +246,22 @@ def _raise_for_bootstrap_mismatch( f"worker scope {scope!r}" ) - traceability = self._fingerprinting_service.build_traceability( - inputs=inputs, - scope=scope, - ) - current_inputs = { - "weights": _artifact_identity_manifest(traceability.weights), - "source_dataset": _artifact_identity_manifest(traceability.source_dataset), - "exact_geography": _artifact_identity_manifest( - traceability.exact_geography - ), - "target_db": _artifact_identity_manifest(traceability.target_db), - "calibration_package": _artifact_identity_manifest( - traceability.calibration_package - ), - "run_config": _artifact_identity_manifest(traceability.run_config), - } - - for logical_name, current_identity in current_inputs.items(): - _assert_manifest_identity_matches( - logical_name=logical_name, - expected=current_identity, - actual=bundle.inputs.get(logical_name), + if not bundle.entity_graph_path.exists(): + raise FileNotFoundError( + f"Bootstrap entity graph not found: {bundle.entity_graph_path}" ) - expected_fingerprint = self._fingerprinting_service.compute_scope_fingerprint( - traceability - ) - actual_fingerprint = bundle.traceability.get("scope_fingerprint") - if actual_fingerprint != expected_fingerprint: + actual_scope_fingerprint = bundle.traceability.get("scope_fingerprint") + if expected_scope_fingerprint is None: + raise ValueError( + "Bootstrap scope fingerprint cannot be validated without an " + "expected scope fingerprint" + ) + + if actual_scope_fingerprint != expected_scope_fingerprint: raise ValueError( - f"Bootstrap scope fingerprint {actual_fingerprint!r} does not " - f"match current fingerprint {expected_fingerprint!r}" + f"Bootstrap scope fingerprint {actual_scope_fingerprint!r} " + f"does not match expected fingerprint {expected_scope_fingerprint!r}" ) def _load_source( @@ -324,41 +309,3 @@ def _load_weights( f"expected {inputs.n_clones}" ) return weights - - -def _artifact_identity_manifest(identity) -> dict[str, Any] | None: - if identity is None: - return None - return { - "logical_name": identity.logical_name, - "sha256": identity.sha256, - "size_bytes": identity.size_bytes, - "metadata": dict(identity.metadata), - } - - -def _assert_manifest_identity_matches( - *, - logical_name: str, - expected: dict[str, Any] | None, - actual, -) -> None: - if expected is None or actual is None: - if expected != actual: - raise ValueError( - f"Bootstrap {logical_name} identity presence does not match " - "current inputs" - ) - return - - actual_manifest = dict(actual) - comparable_actual = { - "logical_name": actual_manifest.get("logical_name"), - "sha256": actual_manifest.get("sha256"), - "size_bytes": actual_manifest.get("size_bytes"), - "metadata": dict(actual_manifest.get("metadata") or {}), - } - if comparable_actual != expected: - raise ValueError( - f"Bootstrap {logical_name} identity does not match current inputs" - ) diff --git a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py index 4834d1068..5cfb7635e 100644 --- a/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py +++ b/tests/integration/build_outputs/h5_worker_runtime/test_worker_script_tiny_fixture.py @@ -42,6 +42,7 @@ def _run_worker( validation_config: Path | None = None, run_id: str = "tiny-worker-run", scope: str = "regional", + scope_fingerprint: str | None = None, artifacts_dir: Path | None = None, return_process: bool = False, ) -> dict: @@ -76,6 +77,8 @@ def _run_worker( ] if artifacts_dir is not None: cmd.extend(["--artifacts-dir", str(artifacts_dir)]) + if scope_fingerprint is not None: + cmd.extend(["--scope-fingerprint", scope_fingerprint]) if not validate: cmd.append("--no-validate") if target_config is not None: @@ -244,6 +247,7 @@ def test_worker_consumes_scope_bootstrap_when_available(tmp_path): inputs=inputs, scope="regional", artifacts_dir=artifacts_dir, + scope_fingerprint="regional-fingerprint", ) result = _run_worker( @@ -253,6 +257,7 @@ def test_worker_consumes_scope_bootstrap_when_available(tmp_path): use_saved_geography=True, use_package_geography=True, run_id="run-123", + scope_fingerprint="regional-fingerprint", artifacts_dir=artifacts_dir, return_process=True, ) diff --git a/tests/unit/build_outputs/test_worker_session.py b/tests/unit/build_outputs/test_worker_session.py index 922134ba6..7ff142ddb 100644 --- a/tests/unit/build_outputs/test_worker_session.py +++ b/tests/unit/build_outputs/test_worker_session.py @@ -1,6 +1,5 @@ from __future__ import annotations -from dataclasses import replace from pathlib import Path from types import SimpleNamespace @@ -75,17 +74,6 @@ def prepare_context(self, **kwargs): ) -class MismatchedWeightFingerprintingService(FakeFingerprintingService): - """Fingerprinting fake that reports a changed current weights identity.""" - - def build_traceability(self, **kwargs): - traceability = super().build_traceability(**kwargs) - return replace( - traceability, - weights=replace(traceability.weights, sha256="sha256:changed-weights"), - ) - - def test_worker_session_caches_are_per_session(tmp_path): artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") @@ -154,13 +142,13 @@ def test_worker_session_factory_prefers_bootstrap_entity_graph(tmp_path): dataset_reader=dataset_reader, geography_loader=SessionGeographyLoader(artifacts), validation_service=FakeValidationService(), - fingerprinting_service=FakeFingerprintingService(), bootstrap_store=store, ).create( inputs=artifacts.inputs, scope="regional", validation_policy=ValidationPolicy(), period=2024, + expected_scope_fingerprint="regional-fingerprint", ) assert session.bootstrap_status == "used" @@ -171,6 +159,40 @@ def test_worker_session_factory_prefers_bootstrap_entity_graph(tmp_path): ) +def test_worker_session_factory_requires_expected_fingerprint_for_bootstrap( + tmp_path, +): + artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") + store = WorkerBootstrapStore(tmp_path / "artifacts") + WorkerBootstrapBuilder( + dataset_reader=FakeDatasetReader(artifacts.snapshot), + geography_loader=FakeGeographyLoader(artifacts), + fingerprinting_service=FakeFingerprintingService(), + ).build( + inputs=artifacts.inputs, + scope="regional", + artifacts_dir=store.artifacts_dir, + ) + dataset_reader = SessionDatasetReader(artifacts.snapshot) + + session = WorkerSessionFactory( + dataset_reader=dataset_reader, + geography_loader=SessionGeographyLoader(artifacts), + validation_service=FakeValidationService(), + bootstrap_store=store, + ).create( + inputs=artifacts.inputs, + scope="regional", + validation_policy=ValidationPolicy(), + period=2024, + ) + + assert session.bootstrap_status == "fallback" + assert session.bootstrap_bundle is None + assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] + assert "expected scope fingerprint" in session.caches["bootstrap_error"] + + def test_worker_session_factory_falls_back_when_bootstrap_source_load_fails(tmp_path): artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") store = WorkerBootstrapStore(tmp_path / "artifacts") @@ -192,13 +214,13 @@ def test_worker_session_factory_falls_back_when_bootstrap_source_load_fails(tmp_ dataset_reader=dataset_reader, geography_loader=SessionGeographyLoader(artifacts), validation_service=FakeValidationService(), - fingerprinting_service=FakeFingerprintingService(), bootstrap_store=store, ).create( inputs=artifacts.inputs, scope="regional", validation_policy=ValidationPolicy(), period=2024, + expected_scope_fingerprint="regional-fingerprint", ) assert session.bootstrap_status == "fallback" @@ -238,7 +260,6 @@ def test_worker_session_factory_falls_back_when_bootstrap_inputs_mismatch(tmp_pa dataset_reader=dataset_reader, geography_loader=SessionGeographyLoader(artifacts), validation_service=FakeValidationService(), - fingerprinting_service=FakeFingerprintingService(), bootstrap_store=store, ).create( inputs=changed_inputs, @@ -253,7 +274,7 @@ def test_worker_session_factory_falls_back_when_bootstrap_inputs_mismatch(tmp_pa assert "does not match worker run_id" in session.caches["bootstrap_error"] -def test_worker_session_factory_falls_back_when_bootstrap_artifact_identity_mismatch( +def test_worker_session_factory_falls_back_when_bootstrap_fingerprint_mismatch( tmp_path, ): artifacts = make_bootstrap_test_artifacts(tmp_path / "inputs") @@ -273,19 +294,19 @@ def test_worker_session_factory_falls_back_when_bootstrap_artifact_identity_mism dataset_reader=dataset_reader, geography_loader=SessionGeographyLoader(artifacts), validation_service=FakeValidationService(), - fingerprinting_service=MismatchedWeightFingerprintingService(), bootstrap_store=store, ).create( inputs=artifacts.inputs, scope="regional", validation_policy=ValidationPolicy(), period=2024, + expected_scope_fingerprint="changed-fingerprint", ) assert session.bootstrap_status == "fallback" assert session.bootstrap_bundle is None assert dataset_reader.loaded_paths == [artifacts.inputs.source_dataset_path] - assert "weights identity does not match" in session.caches["bootstrap_error"] + assert "does not match expected fingerprint" in session.caches["bootstrap_error"] def test_worker_session_factory_marks_corrupt_bootstrap_as_fallback(tmp_path): diff --git a/tests/unit/test_modal_local_area.py b/tests/unit/test_modal_local_area.py index b9351e7f1..da847fd24 100644 --- a/tests/unit/test_modal_local_area.py +++ b/tests/unit/test_modal_local_area.py @@ -396,8 +396,10 @@ def test_build_areas_worker_surfaces_successful_worker_stderr( "staging_volume", SimpleNamespace(reload=lambda: None, commit=lambda: None), ) + captured_cmd = {} def fake_run(cmd, **kwargs): + captured_cmd["cmd"] = cmd return SimpleNamespace( returncode=0, stdout='{"completed": ["district:NC-01"], "failed": [], "errors": []}', @@ -417,8 +419,11 @@ def fake_run(cmd, **kwargs): "database": "/tmp/policy_data.db", }, validate=False, + scope_fingerprint="regional-fingerprint", ) captured = capsys.readouterr() assert result["completed"] == ["district:NC-01"] assert "Worker session ready: scope=regional, bootstrap=used" in captured.err + assert "--scope-fingerprint" in captured_cmd["cmd"] + assert "regional-fingerprint" in captured_cmd["cmd"] diff --git a/tests/unit/test_modal_worker_script.py b/tests/unit/test_modal_worker_script.py index 3d1ce8f72..690f0a3b7 100644 --- a/tests/unit/test_modal_worker_script.py +++ b/tests/unit/test_modal_worker_script.py @@ -78,6 +78,8 @@ def test_parse_args_accepts_worker_session_paths(): "/tmp/artifacts/run-123", "--run-config-path", "/tmp/unified_run_config.json", + "--scope-fingerprint", + "regional-fingerprint", ] ) @@ -85,6 +87,7 @@ def test_parse_args_accepts_worker_session_paths(): assert args.scope == "national" assert args.artifacts_dir == "/tmp/artifacts/run-123" assert args.run_config_path == "/tmp/unified_run_config.json" + assert args.scope_fingerprint == "regional-fingerprint" assert not hasattr(args, "version") From f7f2accf553e1e1c29d41e2d55c278f5d63b21f5 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 12 May 2026 16:30:04 +0200 Subject: [PATCH 16/17] Tighten H5 worker diagnostics and scope tests --- modal_app/worker_script.py | 24 ++++++++++++++++------ tests/integration/test_tiny_h5_pipeline.py | 22 ++++++++++++++------ tests/unit/test_modal_worker_script.py | 18 ++++++++++++++++ 3 files changed, 52 insertions(+), 12 deletions(-) diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index 59350a1f3..92f7b752d 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -348,6 +348,23 @@ def _resolve_request_input( return _request_key(request), request +def _log_worker_session_ready(*, scope: str, session, geography) -> None: + """Write worker-session setup details to stderr for Modal diagnostics.""" + + print( + "Worker session ready: " + f"scope={scope}, bootstrap={session.bootstrap_status}, " + f"{geography.n_clones} clones x {geography.n_records} records", + file=sys.stderr, + ) + bootstrap_error = session.caches.get("bootstrap_error") + if bootstrap_error: + print( + f"Worker bootstrap fallback reason: {bootstrap_error}", + file=sys.stderr, + ) + + def main(argv: list[str] | None = None): args = parse_args(argv) @@ -407,12 +424,7 @@ def main(argv: list[str] | None = None): constraints_map = ( validation_context.constraints_map if validation_context is not None else None ) - print( - "Worker session ready: " - f"scope={scope}, bootstrap={session.bootstrap_status}, " - f"{geography.n_clones} clones x {geography.n_records} records", - file=sys.stderr, - ) + _log_worker_session_ready(scope=scope, session=session, geography=geography) if validation_targets is not None: print( f"Validation ready: {len(validation_targets)} targets, " diff --git a/tests/integration/test_tiny_h5_pipeline.py b/tests/integration/test_tiny_h5_pipeline.py index 34954b97e..372d65553 100644 --- a/tests/integration/test_tiny_h5_pipeline.py +++ b/tests/integration/test_tiny_h5_pipeline.py @@ -83,22 +83,32 @@ def test_saved_geography_h5_pipeline_builds_regional_and_national_outputs(): assert preflight_result["geography_source"] == "saved_geography" assert len(preflight_result["fingerprint"]) == 16 - build_result = build.remote( + regional_result = build.remote( branch="main", run_id=run_id, scope="regional", - work_items=_work_items("district", "state", "national"), + work_items=_work_items("district", "state"), + calibration_inputs=preflight_result["calibration_inputs"], + validate=False, + ) + national_result = build.remote( + branch="main", + run_id=run_id, + scope="national", + work_items=_work_items("national"), calibration_inputs=preflight_result["calibration_inputs"], validate=False, ) - assert build_result["failed"] == [] - assert build_result["errors"] == [] - assert build_result["completed"] == [ + assert regional_result["failed"] == [] + assert regional_result["errors"] == [] + assert regional_result["completed"] == [ "district:NC-01", "state:NC", - "national:US", ] + assert national_result["failed"] == [] + assert national_result["errors"] == [] + assert national_result["completed"] == ["national:US"] manifest = validate.remote(branch="main", run_id=run_id, version="0.0.0") assert manifest["totals"]["districts"] == 1 diff --git a/tests/unit/test_modal_worker_script.py b/tests/unit/test_modal_worker_script.py index 690f0a3b7..2ce7d26e7 100644 --- a/tests/unit/test_modal_worker_script.py +++ b/tests/unit/test_modal_worker_script.py @@ -202,3 +202,21 @@ def test_resolve_output_path_rejects_escaped_request_path(tmp_path): assert "must stay within the worker output_dir" in str(exc) else: raise AssertionError("Expected _resolve_output_path to reject traversal") + + +def test_log_worker_session_ready_includes_bootstrap_fallback_reason(capsys): + session = SimpleNamespace( + bootstrap_status="fallback", + caches={"bootstrap_error": "entity graph load failed"}, + ) + geography = SimpleNamespace(n_clones=2, n_records=10) + + worker_script._log_worker_session_ready( + scope="regional", + session=session, + geography=geography, + ) + + captured = capsys.readouterr() + assert "Worker session ready: scope=regional, bootstrap=fallback" in captured.err + assert "Worker bootstrap fallback reason: entity graph load failed" in captured.err From 57a43889e5e9495b016a5c20ac5d42761619309c Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 12 May 2026 19:01:55 +0200 Subject: [PATCH 17/17] Move H5 area validation behind service --- modal_app/worker_script.py | 182 ++------------ .../build_outputs/validation.py | 233 ++++++++++++++++++ tests/unit/build_outputs/test_validation.py | 143 +++++++++++ 3 files changed, 397 insertions(+), 161 deletions(-) diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index 92f7b752d..d89d22b1c 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -13,122 +13,6 @@ from typing import Any -def _validate_in_subprocess( - h5_path, - area_type, - area_id, - display_id, - area_targets, - area_training, - constraints_map, - db_path, - period, -): - """Run validation for one area inside a subprocess. - - All Microsimulation memory is reclaimed when the - subprocess exits. - """ - import logging - - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s %(message)s", - ) - from policyengine_us import Microsimulation - from sqlalchemy import create_engine as _ce - from policyengine_us_data.calibration.validate_staging import ( - validate_area, - _build_variable_entity_map, - ) - - engine = _ce(f"sqlite:///{db_path}") - sim = Microsimulation(dataset=h5_path) - variable_entity_map = _build_variable_entity_map(sim) - - results = validate_area( - sim=sim, - targets_df=area_targets, - engine=engine, - area_type=area_type, - area_id=area_id, - display_id=display_id, - dataset_path=h5_path, - period=period, - training_mask=area_training, - variable_entity_map=variable_entity_map, - constraints_map=constraints_map, - ) - return results - - -def _validate_h5_subprocess( - h5_path, - request, - validation_targets, - training_mask_full, - constraints_map, - db_path, - period, -): - """Spawn a subprocess to validate one H5 file. - - Uses multiprocessing spawn to isolate memory. - """ - import multiprocessing as _mp - - geo_level = request.validation_geo_level - geographic_ids = tuple(str(item) for item in request.validation_geographic_ids) - if geo_level is None: - return [] - area_type = { - "state": "states", - "district": "districts", - "city": "cities", - "national": "national", - }.get(request.area_type) - if area_type is None: - return [] - display_id = request.display_name - - # Filter targets to matching area - if request.area_type == "national": - mask = validation_targets["geo_level"] == geo_level - else: - mask = (validation_targets["geo_level"] == geo_level) & ( - validation_targets["geographic_id"].astype(str).isin(geographic_ids) - ) - - area_targets = validation_targets[mask].reset_index(drop=True) - area_training = training_mask_full[mask.values] - - if len(area_targets) == 0: - return [] - - # Filter constraints_map to relevant strata - area_strata = area_targets["stratum_id"].unique().tolist() - area_constraints = {int(s): constraints_map.get(int(s), []) for s in area_strata} - - ctx = _mp.get_context("spawn") - with ctx.Pool(1) as pool: - results = pool.apply( - _validate_in_subprocess, - ( - h5_path, - area_type, - request.area_id, - display_id, - area_targets, - area_training, - area_constraints, - db_path, - period, - ), - ) - - return results - - def parse_args(argv: list[str] | None = None): """Parse worker arguments for legacy and typed request inputs.""" @@ -386,7 +270,10 @@ def main(argv: list[str] | None = None): ) from policyengine_us_data.build_outputs.area_catalog import USAreaCatalog from policyengine_us_data.build_outputs.requests import AreaBuildRequest - from policyengine_us_data.build_outputs.validation import ValidationPolicy + from policyengine_us_data.build_outputs.validation import ( + AreaValidationService, + ValidationPolicy, + ) from policyengine_us_data.build_outputs.worker_session import WorkerSessionFactory area_catalog = USAreaCatalog.default() @@ -396,8 +283,9 @@ def main(argv: list[str] | None = None): ) scope = args.scope inputs = _build_publishing_inputs(args=args, run_id=run_id) + validation_service = AreaValidationService() - session = WorkerSessionFactory().create( + session = WorkerSessionFactory(validation_service=validation_service).create( inputs=inputs, scope=scope, validation_policy=ValidationPolicy(enabled=not args.no_validate), @@ -413,22 +301,14 @@ def main(argv: list[str] | None = None): n_records = session.weights.n_records geography = session.geography validation_context = session.validation_context - validation_targets = ( - validation_context.validation_targets - if validation_context is not None - else None - ) - training_mask_full = ( - validation_context.training_mask if validation_context is not None else None - ) - constraints_map = ( - validation_context.constraints_map if validation_context is not None else None - ) _log_worker_session_ready(scope=scope, session=session, geography=geography) - if validation_targets is not None: + if ( + validation_context is not None + and validation_context.validation_targets is not None + ): print( - f"Validation ready: {len(validation_targets)} targets, " - f"{len(constraints_map or {})} strata", + f"Validation ready: {len(validation_context.validation_targets)} targets, " + f"{len(validation_context.constraints_map or {})} strata", file=sys.stderr, ) @@ -497,42 +377,22 @@ def main(argv: list[str] | None = None): file=sys.stderr, ) - # ── Per-item validation ── - if not args.no_validate and validation_targets is not None: + if not args.no_validate and validation_context is not None: try: - v_rows = _validate_h5_subprocess( + validation_result = validation_service.validate_request( + context=validation_context, h5_path=str(path), request=request, - validation_targets=validation_targets, - training_mask_full=training_mask_full, - constraints_map=constraints_map, - db_path=str(inputs.target_db_path), - period=args.period, ) + v_rows = list(validation_result.rows) results["validation_rows"].extend(v_rows) - n_fail = sum( - 1 for r in v_rows if r.get("sanity_check") == "FAIL" - ) - rae_vals = [ - r["rel_abs_error"] - for r in v_rows - if isinstance( - r.get("rel_abs_error"), - (int, float), - ) - and r["rel_abs_error"] != float("inf") - ] - mean_rae = sum(rae_vals) / len(rae_vals) if rae_vals else 0.0 - results["validation_summary"][request_key] = { - "n_targets": len(v_rows), - "n_sanity_fail": n_fail, - "mean_rel_abs_error": round(mean_rae, 4), - } + summary = dict(validation_result.summary) + results["validation_summary"][request_key] = summary print( f" Validated {request_key}: " - f"{len(v_rows)} targets, " - f"{n_fail} sanity fails, " - f"mean RAE={mean_rae:.4f}", + f"{summary['n_targets']} targets, " + f"{summary['n_sanity_fail']} sanity fails, " + f"mean RAE={summary['mean_rel_abs_error']:.4f}", file=sys.stderr, ) except Exception as ve: diff --git a/policyengine_us_data/build_outputs/validation.py b/policyengine_us_data/build_outputs/validation.py index c0d8d61a6..534f53045 100644 --- a/policyengine_us_data/build_outputs/validation.py +++ b/policyengine_us_data/build_outputs/validation.py @@ -15,6 +15,7 @@ __all__ = [ "AreaValidationService", "ValidationContext", + "AreaValidationResult", "ValidationPolicy", ] @@ -93,6 +94,24 @@ def __post_init__(self) -> None: ) +@dataclass(frozen=True) +class AreaValidationResult: + """Validation rows and summary for one built area H5.""" + + rows: tuple[Mapping[str, Any], ...] = () + summary: Mapping[str, Any] | None = None + + def __post_init__(self) -> None: + rows = tuple(self.rows) + summary = ( + dict(self.summary) + if self.summary is not None + else _validation_summary(rows) + ) + object.__setattr__(self, "rows", rows) + object.__setattr__(self, "summary", summary) + + @pipeline_node( id="local_h5_area_validation_service", label="AreaValidationService", @@ -116,6 +135,7 @@ def __init__( batch_constraints: Callable[[Any, list[int]], Mapping[int, Any]] | None = None, load_target_config: Callable[[Path | str], Mapping[str, Any]] | None = None, match_rules: Callable[[Any, list[Mapping[str, Any]]], Any] | None = None, + validate_h5: Callable[..., list[Mapping[str, Any]]] | None = None, ) -> None: """Create a validation service with injectable seams for tests.""" @@ -124,6 +144,7 @@ def __init__( self._batch_constraints = batch_constraints self._load_target_config = load_target_config self._match_rules = match_rules + self._validate_h5 = validate_h5 def prepare_context( self, @@ -184,6 +205,70 @@ def prepare_context( validation_config_path=validation_config_path, ) + def validate_request( + self, + *, + context: ValidationContext | None, + h5_path: Path | str, + request, + ) -> AreaValidationResult: + """Validate one built area H5 against a prepared worker context.""" + + if ( + context is None + or context.validation_targets is None + or context.target_db_path is None + ): + return AreaValidationResult() + + geo_level = request.validation_geo_level + geographic_ids = tuple(str(item) for item in request.validation_geographic_ids) + if geo_level is None: + return AreaValidationResult() + + area_type = { + "state": "states", + "district": "districts", + "city": "cities", + "national": "national", + }.get(request.area_type) + if area_type is None: + return AreaValidationResult() + + targets = context.validation_targets + if request.area_type == "national": + mask = targets["geo_level"] == geo_level + else: + mask = (targets["geo_level"] == geo_level) & ( + targets["geographic_id"].astype(str).isin(geographic_ids) + ) + + area_targets = targets[mask].reset_index(drop=True) + if len(area_targets) == 0: + return AreaValidationResult() + + area_training = self._area_training_mask(context.training_mask, mask) + area_constraints = self._area_constraints( + area_targets, + context.constraints_map or {}, + ) + rows = tuple( + self._run_validation( + h5_path=str(h5_path), + request=request, + area_type=area_type, + area_targets=area_targets, + area_training=area_training, + constraints_map=area_constraints, + db_path=str(context.target_db_path), + period=context.period, + ) + ) + return AreaValidationResult( + rows=rows, + summary=_validation_summary(rows), + ) + def _create_engine(self, target_db_path: Path): if self._engine_factory is not None: return self._engine_factory(f"sqlite:///{target_db_path}") @@ -255,3 +340,151 @@ def _training_mask(self, validation_targets, config_path: Path | None): if not include_rules: return np.ones(len(validation_targets), dtype=bool) return np.asarray(self._match(validation_targets, include_rules), dtype=bool) + + def _area_training_mask(self, training_mask, area_mask): + if training_mask is None: + return np.ones(int(np.sum(area_mask.to_numpy())), dtype=bool) + return training_mask[area_mask.values] + + def _area_constraints( + self, + area_targets, + constraints_map: Mapping[int, Any], + ) -> dict[int, Any]: + area_strata = area_targets["stratum_id"].unique().tolist() + return { + int(stratum): constraints_map.get(int(stratum), []) + for stratum in area_strata + } + + def _run_validation( + self, + *, + h5_path: str, + request, + area_type: str, + area_targets, + area_training, + constraints_map: Mapping[int, Any], + db_path: str, + period: int, + ) -> list[Mapping[str, Any]]: + if self._validate_h5 is not None: + return self._validate_h5( + h5_path=h5_path, + request=request, + area_type=area_type, + area_targets=area_targets, + area_training=area_training, + constraints_map=constraints_map, + db_path=db_path, + period=period, + ) + return _validate_h5_subprocess( + h5_path=h5_path, + request=request, + area_type=area_type, + area_targets=area_targets, + area_training=area_training, + constraints_map=constraints_map, + db_path=db_path, + period=period, + ) + + +def _validation_summary(rows: tuple[Mapping[str, Any], ...]) -> dict[str, Any]: + n_fail = sum(1 for row in rows if row.get("sanity_check") == "FAIL") + rae_values = [ + row["rel_abs_error"] + for row in rows + if isinstance(row.get("rel_abs_error"), (int, float)) + and row["rel_abs_error"] != float("inf") + ] + mean_rae = sum(rae_values) / len(rae_values) if rae_values else 0.0 + return { + "n_targets": len(rows), + "n_sanity_fail": n_fail, + "mean_rel_abs_error": round(mean_rae, 4), + } + + +def _validate_h5_subprocess( + *, + h5_path: str, + request, + area_type: str, + area_targets, + area_training, + constraints_map: Mapping[int, Any], + db_path: str, + period: int, +) -> list[Mapping[str, Any]]: + """Spawn a subprocess to validate one H5 file and reclaim simulation memory.""" + + import multiprocessing as mp + + ctx = mp.get_context("spawn") + with ctx.Pool(1) as pool: + return pool.apply( + _validate_area_in_subprocess, + ( + h5_path, + area_type, + request.area_id, + request.display_name, + area_targets, + area_training, + constraints_map, + db_path, + period, + ), + ) + + +def _validate_area_in_subprocess( + h5_path, + area_type, + area_id, + display_id, + area_targets, + area_training, + constraints_map, + db_path, + period, +): + """Run validation for one area in an isolated subprocess.""" + + import logging + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", + ) + from policyengine_us import Microsimulation + from sqlalchemy import create_engine + from policyengine_us_data.calibration.validate_staging import ( + _build_variable_entity_map, + validate_area, + ) + + engine = create_engine(f"sqlite:///{db_path}") + try: + sim = Microsimulation(dataset=h5_path) + variable_entity_map = _build_variable_entity_map(sim) + return validate_area( + sim=sim, + targets_df=area_targets, + engine=engine, + area_type=area_type, + area_id=area_id, + display_id=display_id, + dataset_path=h5_path, + period=period, + training_mask=area_training, + variable_entity_map=variable_entity_map, + constraints_map=constraints_map, + ) + finally: + dispose = getattr(engine, "dispose", None) + if callable(dispose): + dispose() diff --git a/tests/unit/build_outputs/test_validation.py b/tests/unit/build_outputs/test_validation.py index f5e8c855e..3455ad96e 100644 --- a/tests/unit/build_outputs/test_validation.py +++ b/tests/unit/build_outputs/test_validation.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +from types import SimpleNamespace import numpy as np import pandas as pd @@ -126,3 +127,145 @@ def match_rules(targets, rules): 2: ["constraint-2"], } assert constraint_calls == [(1, 2)] + + +def test_validation_service_validates_one_area_from_prepared_context(): + calls = [] + targets = pd.DataFrame( + { + "variable": ["household_count", "state_income", "national_income"], + "stratum_id": [1, 2, 3], + "geo_level": ["district", "state", "national"], + "geographic_id": ["3701", "37", "US"], + } + ) + context = ValidationContext( + policy=ValidationPolicy(), + target_db_path=Path("/tmp/policy_data.db"), + period=2024, + validation_targets=targets, + training_mask=np.array([True, False, True]), + constraints_map={1: ["constraint-1"], 2: ["constraint-2"], 3: ["constraint-3"]}, + ) + request = SimpleNamespace( + area_type="district", + area_id="NC-01", + display_name="NC-01", + validation_geo_level="district", + validation_geographic_ids=("3701",), + ) + + def validate_h5(**kwargs): + calls.append(kwargs) + return [ + {"sanity_check": "PASS", "rel_abs_error": 0.25}, + {"sanity_check": "FAIL", "rel_abs_error": float("inf")}, + ] + + result = AreaValidationService(validate_h5=validate_h5).validate_request( + context=context, + h5_path=Path("/tmp/NC-01.h5"), + request=request, + ) + + assert len(calls) == 1 + assert calls[0]["h5_path"] == "/tmp/NC-01.h5" + assert calls[0]["area_type"] == "districts" + assert calls[0]["area_targets"]["variable"].tolist() == ["household_count"] + assert np.array_equal(calls[0]["area_training"], np.array([True])) + assert calls[0]["constraints_map"] == {1: ["constraint-1"]} + assert calls[0]["db_path"] == "/tmp/policy_data.db" + assert calls[0]["period"] == 2024 + assert result.summary == { + "n_targets": 2, + "n_sanity_fail": 1, + "mean_rel_abs_error": 0.25, + } + + +def test_validation_service_filters_national_targets_without_geographic_id(): + calls = [] + targets = pd.DataFrame( + { + "variable": ["state_income", "national_income"], + "stratum_id": [2, 3], + "geo_level": ["state", "national"], + "geographic_id": ["37", "US"], + } + ) + context = ValidationContext( + policy=ValidationPolicy(), + target_db_path=Path("/tmp/policy_data.db"), + period=2024, + validation_targets=targets, + training_mask=np.array([False, True]), + constraints_map={2: ["constraint-2"], 3: ["constraint-3"]}, + ) + request = SimpleNamespace( + area_type="national", + area_id="US", + display_name="US", + validation_geo_level="national", + validation_geographic_ids=("ignored",), + ) + + def validate_h5(**kwargs): + calls.append(kwargs) + return [{"sanity_check": "PASS", "rel_abs_error": 0.0}] + + result = AreaValidationService(validate_h5=validate_h5).validate_request( + context=context, + h5_path=Path("/tmp/US.h5"), + request=request, + ) + + assert calls[0]["area_type"] == "national" + assert calls[0]["area_targets"]["variable"].tolist() == ["national_income"] + assert np.array_equal(calls[0]["area_training"], np.array([True])) + assert calls[0]["constraints_map"] == {3: ["constraint-3"]} + assert result.summary["n_targets"] == 1 + + +def test_validation_service_returns_empty_result_for_unmatched_area(): + called = False + context = ValidationContext( + policy=ValidationPolicy(), + target_db_path=Path("/tmp/policy_data.db"), + period=2024, + validation_targets=pd.DataFrame( + { + "variable": ["state_income"], + "stratum_id": [2], + "geo_level": ["state"], + "geographic_id": ["37"], + } + ), + training_mask=np.array([True]), + constraints_map={2: ["constraint-2"]}, + ) + request = SimpleNamespace( + area_type="district", + area_id="NC-01", + display_name="NC-01", + validation_geo_level="district", + validation_geographic_ids=("3701",), + ) + + def validate_h5(**kwargs): + nonlocal called + called = True + return [] + + result = AreaValidationService(validate_h5=validate_h5).validate_request( + context=context, + h5_path=Path("/tmp/NC-01.h5"), + request=request, + ) + + assert called is False + assert result.rows == () + assert result.summary == { + "n_targets": 0, + "n_sanity_fail": 0, + "mean_rel_abs_error": 0.0, + }