Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/951.added
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add worker-scoped setup and validation context contracts for local H5 builds.
29 changes: 17 additions & 12 deletions modal_app/fixtures/h5_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
import sqlite3
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from policyengine_us_data.build_outputs.worker_inputs import (
WorkerCalibrationInputPayload,
WorkerCalibrationInputs,
)

FIXTURE_DATASET_PATH = Path(
"/root/policyengine-us-data/tests/integration/test_fixture_50hh.h5"
Expand All @@ -28,7 +32,7 @@ class SeededCase:
"""Description of one tiny end-to-end H5 test case."""

name: str
calibration_inputs: dict[str, Any]
calibration_inputs: WorkerCalibrationInputPayload
expected_district_name: str = DISTRICT_NAME
n_clones: int = N_CLONES
seed: int = SEED
Expand Down Expand Up @@ -192,13 +196,8 @@ def seed_case(
weights_path = _write_weights(artifact_dir, n_records=n_records)
db_path = _write_db(artifact_dir)

calibration_inputs: dict[str, Any] = {
"weights": str(weights_path),
"dataset": str(dataset_path),
"database": str(db_path),
"n_clones": N_CLONES,
"seed": SEED,
}
geography_path = None
package_path = None

if case_name == "saved_geography_success":
geography_path = _write_saved_geography(artifact_dir, n_records=n_records)
Expand All @@ -207,11 +206,9 @@ def seed_case(
weights_path=weights_path,
geography_path=geography_path,
)
calibration_inputs["geography"] = str(geography_path)
elif case_name == "package_fallback_success":
package_path = _write_calibration_package(artifact_dir, n_records=n_records)
_write_run_config(artifact_dir, weights_path=weights_path)
calibration_inputs["calibration_package"] = str(package_path)
elif case_name == "misnamed_package":
_write_misnamed_package(artifact_dir, n_records=n_records)
_write_run_config(artifact_dir, weights_path=weights_path)
Expand All @@ -220,5 +217,13 @@ def seed_case(

return SeededCase(
name=case_name,
calibration_inputs=calibration_inputs,
calibration_inputs=WorkerCalibrationInputs(
weights_path=weights_path,
dataset_path=dataset_path,
database_path=db_path,
geography_path=geography_path,
calibration_package_path=package_path,
n_clones=N_CLONES,
seed=SEED,
).to_wire_dict(),
)
47 changes: 23 additions & 24 deletions modal_app/h5_test_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@

from modal_app.images import cpu_image as image # noqa: E402
from modal_app.local_area import VOLUME_MOUNT, pipeline_volume, staging_volume # noqa: E402
from policyengine_us_data.build_outputs.worker_inputs import ( # noqa: E402
WorkerCalibrationInputPayload,
WorkerCalibrationInputs,
)


app = modal.App(
Expand Down Expand Up @@ -70,19 +74,16 @@ def _calibration_inputs(
calibration_package_path: Path | None = None,
n_clones: int = 1,
seed: int = 42,
) -> dict:
inputs = {
"weights": str(weights_path),
"dataset": str(dataset_path),
"database": str(db_path),
"n_clones": n_clones,
"seed": seed,
}
if geography_path is not None:
inputs["geography"] = str(geography_path)
if calibration_package_path is not None:
inputs["calibration_package"] = str(calibration_package_path)
return inputs
) -> WorkerCalibrationInputPayload:
return WorkerCalibrationInputs(
weights_path=weights_path,
dataset_path=dataset_path,
database_path=db_path,
geography_path=geography_path,
calibration_package_path=calibration_package_path,
n_clones=n_clones,
seed=seed,
).to_wire_dict()


@app.function(
Expand Down Expand Up @@ -290,17 +291,15 @@ def preflight_h5_case(run_id: str, *, n_clones: int = 1) -> dict:
calibration_package_path=package_path if package_path.exists() else None,
blocks_path=artifact_dir / "stacked_blocks.npy",
)
calibration_inputs = {
"weights": str(weights_path),
"dataset": str(dataset_path),
"database": str(db_path),
"n_clones": n_clones,
"seed": SEED,
}
if geography_path.exists():
calibration_inputs["geography"] = str(geography_path)
if package_path.exists():
calibration_inputs["calibration_package"] = str(package_path)
calibration_inputs = WorkerCalibrationInputs.from_artifact_paths(
weights_path=weights_path,
dataset_path=dataset_path,
database_path=db_path,
geography_path=geography_path,
calibration_package_path=package_path,
n_clones=n_clones,
seed=SEED,
).to_wire_dict()
return {
"fingerprint": fingerprint,
"geography_source": resolved.kind if resolved is not None else None,
Expand Down
119 changes: 76 additions & 43 deletions modal_app/local_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import sys
import traceback
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Mapping

import modal

Expand All @@ -39,6 +39,9 @@
from policyengine_us_data.build_outputs.partitioning import ( # noqa: E402
partition_weighted_work_items,
)
from policyengine_us_data.build_outputs.worker_inputs import ( # noqa: E402
WorkerCalibrationInputs,
)
from policyengine_us_data.pipeline_metadata import pipeline_node # noqa: E402
from policyengine_us_data.pipeline_schema import PipelineNode # noqa: E402
from policyengine_us_data.utils.run_context import resolve_run_id # noqa: E402
Expand Down Expand Up @@ -489,6 +492,31 @@ def _build_worker_bootstrap(
return bundle


def _build_worker_calibration_inputs(
*,
weights_path: Path,
geography_path: Path,
dataset_path: Path,
db_path: Path,
n_clones: int,
seed: int,
run_config_path: Path | None = None,
calibration_package_path: Path | None = None,
) -> WorkerCalibrationInputs:
"""Build the normalized H5 worker input payload."""

return WorkerCalibrationInputs.from_artifact_paths(
weights_path=weights_path,
geography_path=geography_path,
dataset_path=dataset_path,
database_path=db_path,
n_clones=n_clones,
seed=seed,
run_config_path=run_config_path,
calibration_package_path=calibration_package_path,
)


@pipeline_node(
PipelineNode(
id="coordinate_work_partition",
Expand Down Expand Up @@ -561,9 +589,10 @@ def run_phase(
completed: set,
branch: str,
run_id: str,
calibration_inputs: Dict[str, str],
calibration_inputs: WorkerCalibrationInputs | Mapping[str, object],
run_dir: Path,
validate: bool = True,
scope_fingerprint: str | None = None,
) -> tuple:
"""Run a single build phase, spawning workers and collecting results.

Expand All @@ -575,6 +604,9 @@ def run_phase(
"""
work_chunks = partition_work(work_items, num_workers, completed)
total_remaining = sum(len(c) for c in work_chunks)
worker_input_payload = WorkerCalibrationInputs.from_wire_dict(
calibration_inputs
).to_wire_dict()

print(f"\n--- Phase: {phase_name} ---")
print(f"Remaining work: {total_remaining} items across {len(work_chunks)} workers")
Expand All @@ -590,9 +622,11 @@ def run_phase(
handle = build_areas_worker.spawn(
branch=branch,
run_id=run_id,
scope="regional",
work_items=chunk,
calibration_inputs=calibration_inputs,
calibration_inputs=worker_input_payload,
validate=validate,
scope_fingerprint=scope_fingerprint,
)
print(f" → fc: {handle.object_id}")
handles.append(handle)
Expand Down Expand Up @@ -685,9 +719,11 @@ def run_phase(
def build_areas_worker(
branch: str,
run_id: str,
scope: str,
work_items: List[Dict],
calibration_inputs: Dict[str, str],
calibration_inputs: WorkerCalibrationInputs | Mapping[str, object],
validate: bool = True,
scope_fingerprint: str | None = None,
) -> Dict:
"""
Worker function that builds a subset of H5 files.
Expand All @@ -702,33 +738,24 @@ def build_areas_worker(
output_dir.mkdir(parents=True, exist_ok=True)

work_items_json = json.dumps(work_items)
worker_inputs = WorkerCalibrationInputs.from_wire_dict(calibration_inputs)

worker_cmd = [
*_python_cmd("-m", "modal_app.worker_script"),
"--work-items",
work_items_json,
"--weights-path",
calibration_inputs["weights"],
"--dataset-path",
calibration_inputs["dataset"],
"--db-path",
calibration_inputs["database"],
*worker_inputs.to_worker_cli_args(),
"--output-dir",
str(output_dir),
"--scope",
scope,
"--run-id",
run_id,
"--artifacts-dir",
str(Path("/pipeline/artifacts") / run_id),
]
if "geography" in calibration_inputs:
worker_cmd.extend(["--geography-path", calibration_inputs["geography"]])
if "calibration_package" in calibration_inputs:
worker_cmd.extend(
[
"--calibration-package-path",
calibration_inputs["calibration_package"],
]
)
if "n_clones" in calibration_inputs:
worker_cmd.extend(["--n-clones", str(calibration_inputs["n_clones"])])
if "seed" in calibration_inputs:
worker_cmd.extend(["--seed", str(calibration_inputs["seed"])])
if scope_fingerprint:
worker_cmd.extend(["--scope-fingerprint", scope_fingerprint])
repo_root = Path("/root/policyengine-us-data")
cal_dir = repo_root / "policyengine_us_data" / "calibration"
worker_cmd.extend(
Expand All @@ -753,8 +780,10 @@ def build_areas_worker(
env=os.environ.copy(),
)

if result.returncode != 0:
if result.stderr:
print(f"Worker stderr:\n{result.stderr}", file=__import__("sys").stderr)

if result.returncode != 0:
return {
"completed": [],
"failed": [f"{item['type']}:{item['id']}" for item in work_items],
Expand Down Expand Up @@ -1085,16 +1114,16 @@ def coordinate_publish(
)
print("All required pipeline artifacts found on volume.")

calibration_inputs = {
"weights": str(weights_path),
"geography": str(geography_path),
"dataset": str(dataset_path),
"database": str(db_path),
"n_clones": n_clones,
"seed": 42,
}
if calibration_package_path.exists():
calibration_inputs["calibration_package"] = str(calibration_package_path)
calibration_inputs = _build_worker_calibration_inputs(
weights_path=weights_path,
geography_path=geography_path,
dataset_path=dataset_path,
db_path=db_path,
n_clones=n_clones,
seed=42,
run_config_path=config_json_path,
calibration_package_path=calibration_package_path,
)
validate_artifacts(config_json_path, artifacts)

if validate:
Expand Down Expand Up @@ -1215,6 +1244,7 @@ def coordinate_publish(
calibration_inputs=calibration_inputs,
run_dir=run_dir,
validate=validate,
scope_fingerprint=fingerprint,
)

accumulated_errors = []
Expand Down Expand Up @@ -1396,14 +1426,15 @@ def coordinate_national_publish(
)
print("All required national pipeline artifacts found.")

calibration_inputs = {
"weights": str(weights_path),
"geography": str(geography_path),
"dataset": str(dataset_path),
"database": str(db_path),
"n_clones": n_clones,
"seed": 42,
}
calibration_inputs = _build_worker_calibration_inputs(
weights_path=weights_path,
geography_path=geography_path,
dataset_path=dataset_path,
db_path=db_path,
n_clones=n_clones,
seed=42,
run_config_path=config_json_path,
)
validate_artifacts(
config_json_path,
artifacts,
Expand Down Expand Up @@ -1444,9 +1475,11 @@ def coordinate_national_publish(
worker_result = build_areas_worker.remote(
branch=branch,
run_id=run_id,
scope="national",
work_items=work_items,
calibration_inputs=calibration_inputs,
calibration_inputs=calibration_inputs.to_wire_dict(),
validate=validate,
scope_fingerprint=fingerprint,
)

print(
Expand Down
Loading