From 1a57a90acea6b17748c2ecb9937bbf6bf835b00a Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 18 Apr 2026 00:24:08 +0200 Subject: [PATCH 01/17] Support stage 6 explicit report specs --- .../routes/report_output_routes.py | 47 ++++++ policyengine_api/routes/simulation_routes.py | 20 +++ .../services/report_output_service.py | 100 ++++++++++++- .../services/report_spec_service.py | 25 ++++ .../services/simulation_service.py | 72 +++++++-- .../services/test_report_output_service.py | 4 +- .../unit/services/test_simulation_service.py | 4 +- tests/unit/test_stage5_routes.py | 141 ++++++++++++++++++ 8 files changed, 388 insertions(+), 25 deletions(-) diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index 48a2ac43a..03430c53f 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -11,6 +11,25 @@ report_output_bp = Blueprint("report_output", __name__) report_output_service = ReportOutputService() +RUN_METADATA_FIELDS = ( + "country_package_version", + "policyengine_version", + "data_version", + "runtime_app_name", + "resolved_dataset", +) + + +def _parse_report_run_metadata(payload: dict) -> dict[str, str | None]: + metadata: dict[str, str | None] = {} + for field_name in RUN_METADATA_FIELDS: + if field_name not in payload: + continue + value = payload.get(field_name) + if value is not None and not isinstance(value, str): + raise BadRequest(f"{field_name} must be a string or null") + metadata[field_name] = value + return metadata @report_output_bp.route("//report", methods=["POST"]) @@ -36,6 +55,8 @@ def create_report_output(country_id: str) -> Response: simulation_1_id = payload.get("simulation_1_id") simulation_2_id = payload.get("simulation_2_id") # Optional year = payload.get("year", CURRENT_YEAR) # Default to current year as string + report_spec_payload = payload.get("report_spec") + report_spec_schema_version = payload.get("report_spec_schema_version") # Validate required fields if simulation_1_id is None: @@ -46,6 +67,26 @@ def create_report_output(country_id: str) -> Response: raise BadRequest("simulation_2_id must be an integer or null") if not isinstance(year, str): raise BadRequest("year must be a string") + if report_spec_payload is not None and not isinstance(report_spec_payload, dict): + raise BadRequest("report_spec must be an object") + if report_spec_schema_version is not None and not isinstance( + report_spec_schema_version, int + ): + raise BadRequest("report_spec_schema_version must be an integer") + + report_spec = None + if report_spec_payload is not None: + try: + report_spec = report_output_service.parse_report_spec_payload( + report_spec_payload, + ( + report_spec_schema_version + if report_spec_schema_version is not None + else 1 + ), + ) + except ValueError as exc: + raise BadRequest(str(exc)) from exc try: # Check if report already exists with these simulation IDs and year @@ -61,6 +102,8 @@ def create_report_output(country_id: str) -> Response: report_output_service.ensure_report_output_dual_write_state( existing_report["id"], country_id=country_id, + explicit_report_spec=report_spec, + report_spec_schema_version=report_spec_schema_version, ) ) # Report already exists, return it with 200 status @@ -82,6 +125,8 @@ def create_report_output(country_id: str) -> Response: simulation_1_id=simulation_1_id, simulation_2_id=simulation_2_id, year=year, + report_spec=report_spec, + report_spec_schema_version=report_spec_schema_version, ) response_body = dict( @@ -175,6 +220,7 @@ def update_report_output(country_id: str) -> Response: report_id = payload.get("id") output = payload.get("output") error_message = payload.get("error_message") + version_manifest_overrides = _parse_report_run_metadata(payload) print(f"Updating report #{report_id} for country {country_id}") # Validate status if provided @@ -204,6 +250,7 @@ def update_report_output(country_id: str) -> Response: status=status, output=output, error_message=error_message, + version_manifest_overrides=version_manifest_overrides, ) if not success: diff --git a/policyengine_api/routes/simulation_routes.py b/policyengine_api/routes/simulation_routes.py index f2bacd6cb..2da34a962 100644 --- a/policyengine_api/routes/simulation_routes.py +++ b/policyengine_api/routes/simulation_routes.py @@ -10,6 +10,24 @@ simulation_bp = Blueprint("simulation", __name__) simulation_service = SimulationService() +RUN_METADATA_FIELDS = ( + "country_package_version", + "policyengine_version", + "data_version", + "runtime_app_name", +) + + +def _parse_simulation_run_metadata(payload: dict) -> dict[str, str | None]: + metadata: dict[str, str | None] = {} + for field_name in RUN_METADATA_FIELDS: + if field_name not in payload: + continue + value = payload.get(field_name) + if value is not None and not isinstance(value, str): + raise BadRequest(f"{field_name} must be a string or null") + metadata[field_name] = value + return metadata @simulation_bp.route("//simulation", methods=["POST"]) @@ -175,6 +193,7 @@ def update_simulation(country_id: str) -> Response: simulation_id = payload.get("id") output = payload.get("output") error_message = payload.get("error_message") + version_manifest_overrides = _parse_simulation_run_metadata(payload) print(f"Updating simulation #{simulation_id} for country {country_id}") # Validate status if provided @@ -200,6 +219,7 @@ def update_simulation(country_id: str) -> Response: status=status, output=output, error_message=error_message, + version_manifest_overrides=version_manifest_overrides, ) if not success: diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index 38b5704fa..346a6eeeb 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -8,6 +8,7 @@ from policyengine_api.services.report_spec_service import ( ECONOMY_REPORT_KINDS, ReportSpec, + REPORT_SPEC_SCHEMA_VERSION, ReportSpecService, ) from policyengine_api.services.run_sync_utils import ( @@ -282,12 +283,13 @@ def _build_version_manifest( report_spec: ReportSpec | None, simulation_1: dict | None = None, simulation_2: dict | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict[str, str | None]: resolved_dataset = None if report_spec is not None and report_spec.report_kind in ECONOMY_REPORT_KINDS: resolved_dataset = report_spec.dataset - return { + version_manifest = { "country_package_version": self._derive_report_country_package_version( simulation_1, simulation_2 ), @@ -300,6 +302,10 @@ def _build_version_manifest( "resolved_dataset": resolved_dataset, "resolved_options_hash": None, } + for key, value in (version_manifest_overrides or {}).items(): + if key in version_manifest and value is not None: + version_manifest[key] = value + return version_manifest def _get_report_spec_status(self, report_spec: ReportSpec) -> str: if report_spec.report_kind in ECONOMY_REPORT_KINDS: @@ -312,10 +318,59 @@ def _upsert_report_spec_in_transaction( report_output: dict, simulation_1: dict | None, simulation_2: dict | None, + explicit_report_spec: ReportSpec | None = None, + report_spec_schema_version: int | None = None, ) -> ReportSpec | None: if simulation_1 is None: + if explicit_report_spec is not None: + raise ValueError( + "Explicit report specs require linked simulations to be present" + ) return None + if explicit_report_spec is not None: + schema_version = ( + report_spec_schema_version + if report_spec_schema_version is not None + else REPORT_SPEC_SCHEMA_VERSION + ) + self.report_spec_service._validate_schema_version(schema_version) + self.report_spec_service.validate_report_spec_matches_context( + report_output, + explicit_report_spec, + simulation_1, + simulation_2, + ) + report_spec_status = "explicit" + existing_spec = parse_json_field(report_output.get("report_spec_json")) + if ( + existing_spec != explicit_report_spec.model_dump() + or report_output.get("report_kind") + != explicit_report_spec.report_kind + or report_output.get("report_spec_schema_version") != schema_version + or report_output.get("report_spec_status") != report_spec_status + ): + tx.query( + """ + UPDATE report_outputs + SET report_kind = ?, report_spec_json = ?, + report_spec_schema_version = ?, report_spec_status = ? + WHERE id = ? + """, + ( + explicit_report_spec.report_kind, + explicit_report_spec.model_dump_json(), + schema_version, + report_spec_status, + report_output["id"], + ), + ) + report_output["report_kind"] = explicit_report_spec.report_kind + report_output["report_spec_json"] = explicit_report_spec.model_dump() + report_output["report_spec_schema_version"] = schema_version + report_output["report_spec_status"] = report_spec_status + return explicit_report_spec + try: report_spec = self.report_spec_service.build_report_spec( report_output=report_output, @@ -545,6 +600,9 @@ def _ensure_report_output_dual_write_state_in_transaction( report_output_id: int, *, country_id: str | None = None, + explicit_report_spec: ReportSpec | None = None, + report_spec_schema_version: int | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict: report_output = self._get_report_output_row( report_output_id, @@ -562,6 +620,8 @@ def _ensure_report_output_dual_write_state_in_transaction( bootstrap_dual_write_state=True, ) except ValueError as exc: + if explicit_report_spec is not None: + raise print( "Skipping linked simulation sync for report output " f"#{report_output_id}. Details: {str(exc)}" @@ -573,12 +633,15 @@ def _ensure_report_output_dual_write_state_in_transaction( report_output, simulation_1, simulation_2, + explicit_report_spec=explicit_report_spec, + report_spec_schema_version=report_spec_schema_version, ) version_manifest = self._build_version_manifest( report_output, report_spec=report_spec, simulation_1=simulation_1, simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, ) runs_descending = self._list_report_runs_descending( report_output_id, queryer=tx @@ -635,15 +698,31 @@ def ensure_report_output_dual_write_state( self, report_output_id: int, country_id: str | None = None, + explicit_report_spec: ReportSpec | None = None, + report_spec_schema_version: int | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict: return database.transaction( lambda tx: self._ensure_report_output_dual_write_state_in_transaction( tx, report_output_id, country_id=country_id, + explicit_report_spec=explicit_report_spec, + report_spec_schema_version=report_spec_schema_version, + version_manifest_overrides=version_manifest_overrides, ) ) + def parse_report_spec_payload( + self, + raw_report_spec: dict, + schema_version: int = REPORT_SPEC_SCHEMA_VERSION, + ) -> ReportSpec: + return self.report_spec_service.parse_report_spec( + raw_report_spec, + schema_version=schema_version, + ) + def get_stored_report_output( self, country_id: str, report_output_id: int ) -> dict | None: @@ -764,6 +843,8 @@ def create_report_output( simulation_1_id: int, simulation_2_id: int | None = None, year: str = "2025", + report_spec: ReportSpec | None = None, + report_spec_schema_version: int | None = None, ) -> dict: """ Create a new report output record with pending status. @@ -789,6 +870,8 @@ def tx_callback(tx): tx, existing_report["id"], country_id=country_id, + explicit_report_spec=report_spec, + report_spec_schema_version=report_spec_schema_version, ) self._require_simulation_exists( @@ -850,6 +933,8 @@ def tx_callback(tx): tx, created_report["id"], country_id=country_id, + explicit_report_spec=report_spec, + report_spec_schema_version=report_spec_schema_version, ) return database.transaction(tx_callback) @@ -899,6 +984,7 @@ def update_report_output( status: str | None = None, output: str | None = None, error_message: str | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> bool: """ Update a report output record with results or error. @@ -921,7 +1007,7 @@ def update_report_output( update_fields.append("error_message = ?") update_values.append(error_message) - if not update_fields: + if not update_fields and not version_manifest_overrides: print("No fields to update") return False @@ -943,14 +1029,16 @@ def tx_callback(tx): "pending or running report run" ) - tx.query( - f"UPDATE report_outputs SET {', '.join(update_fields)} WHERE id = ? AND country_id = ?", - (*update_values, report_id, country_id), - ) + if update_fields: + tx.query( + f"UPDATE report_outputs SET {', '.join(update_fields)} WHERE id = ? AND country_id = ?", + (*update_values, report_id, country_id), + ) self._ensure_report_output_dual_write_state_in_transaction( tx, report_id, country_id=country_id, + version_manifest_overrides=version_manifest_overrides, ) database.transaction(tx_callback) diff --git a/policyengine_api/services/report_spec_service.py b/policyengine_api/services/report_spec_service.py index b81cc566f..aa77e5fa2 100644 --- a/policyengine_api/services/report_spec_service.py +++ b/policyengine_api/services/report_spec_service.py @@ -211,6 +211,20 @@ def _validate_report_spec_matches_row( self, report_output: dict, report_spec: ReportSpec ) -> None: simulation_1, simulation_2 = self._get_linked_simulations(report_output) + self.validate_report_spec_matches_context( + report_output, + report_spec, + simulation_1, + simulation_2, + ) + + def validate_report_spec_matches_context( + self, + report_output: dict, + report_spec: ReportSpec, + simulation_1: dict, + simulation_2: dict | None = None, + ) -> None: inferred_report_kind = self.infer_report_kind(simulation_1, simulation_2) if report_spec.country_id != report_output["country_id"]: raise ValueError("Report spec country must match report output country") @@ -268,6 +282,17 @@ def _validate_report_spec_matches_row( "Report spec reform_policy_id must match linked simulations" ) + def parse_report_spec( + self, + raw_spec: dict, + schema_version: int = REPORT_SPEC_SCHEMA_VERSION, + ) -> ReportSpec: + self._validate_schema_version(schema_version) + report_kind = raw_spec.get("report_kind") + if report_kind is None: + raise ValueError("Report spec is missing report_kind") + return self._parse_report_spec(report_kind, raw_spec) + def infer_report_kind( self, simulation_1: dict, diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index e5582ee17..97a7027fa 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -62,14 +62,22 @@ def _find_existing_simulation_row( ).fetchone() return dict(row) if row is not None else None - def _build_version_manifest(self, simulation: dict) -> dict[str, str | None]: - return { + def _build_version_manifest( + self, + simulation: dict, + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + version_manifest = { "country_package_version": simulation.get("api_version"), "policyengine_version": None, "data_version": None, "runtime_app_name": None, "simulation_cache_version": None, } + for key, value in (version_manifest_overrides or {}).items(): + if key in version_manifest and value is not None: + version_manifest[key] = value + return version_manifest def _list_simulation_runs_descending( self, simulation_id: int, *, queryer=None @@ -134,8 +142,12 @@ def _run_matches_parent( run: dict, simulation: dict, simulation_spec: SimulationSpec, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> bool: - version_manifest = self._build_version_manifest(simulation) + version_manifest = self._build_version_manifest( + simulation, + version_manifest_overrides=version_manifest_overrides, + ) return ( run["status"] == simulation["status"] and run.get("output") == simulation.get("output") @@ -152,9 +164,16 @@ def _run_matches_parent( ) def _insert_bootstrap_run( - self, tx, simulation: dict, simulation_spec: SimulationSpec + self, + tx, + simulation: dict, + simulation_spec: SimulationSpec, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> None: - version_manifest = self._build_version_manifest(simulation) + version_manifest = self._build_version_manifest( + simulation, + version_manifest_overrides=version_manifest_overrides, + ) tx.query( """ INSERT INTO simulation_runs ( @@ -194,8 +213,12 @@ def _update_simulation_run_in_transaction( run_id: str, simulation: dict, simulation_spec: SimulationSpec, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> None: - version_manifest = self._build_version_manifest(simulation) + version_manifest = self._build_version_manifest( + simulation, + version_manifest_overrides=version_manifest_overrides, + ) tx.query( """ UPDATE simulation_runs @@ -253,6 +276,7 @@ def _ensure_simulation_dual_write_state_in_transaction( simulation_id: int, *, country_id: str | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict: simulation = self._get_simulation_row( simulation_id, @@ -268,7 +292,12 @@ def _ensure_simulation_dual_write_state_in_transaction( simulation_id, queryer=tx ) if not runs_descending: - self._insert_bootstrap_run(tx, simulation, simulation_spec) + self._insert_bootstrap_run( + tx, + simulation, + simulation_spec, + version_manifest_overrides=version_manifest_overrides, + ) runs_descending = self._list_simulation_runs_descending( simulation_id, queryer=tx ) @@ -278,12 +307,14 @@ def _ensure_simulation_dual_write_state_in_transaction( mutable_run, simulation, simulation_spec, + version_manifest_overrides=version_manifest_overrides, ): self._update_simulation_run_in_transaction( tx, run_id=mutable_run["id"], simulation=simulation, simulation_spec=simulation_spec, + version_manifest_overrides=version_manifest_overrides, ) runs_descending = self._list_simulation_runs_descending( simulation_id, queryer=tx @@ -300,13 +331,17 @@ def _ensure_simulation_dual_write_state_in_transaction( return refreshed_simulation def ensure_simulation_dual_write_state( - self, simulation_id: int, country_id: str | None = None + self, + simulation_id: int, + country_id: str | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict: return database.transaction( lambda tx: self._ensure_simulation_dual_write_state_in_transaction( tx, simulation_id, country_id=country_id, + version_manifest_overrides=version_manifest_overrides, ) ) @@ -459,6 +494,7 @@ def update_simulation( status: str | None = None, output: str | None = None, error_message: str | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> bool: """ Update a simulation record with results or error. @@ -497,12 +533,16 @@ def update_simulation( # previous code appended api_version unconditionally, so # the "no fields to update" guard below never fired and a # PATCH with an empty body still touched the row. - if not update_fields: + if not update_fields and not version_manifest_overrides: print("No fields to update") return False - update_fields.append("api_version = ?") - update_values.append(api_version) + # Metadata-only PATCHes update the run manifest, not the + # parent simulation row; only append api_version when + # caller-supplied parent fields are changing. + if update_fields: + update_fields.append("api_version = ?") + update_values.append(api_version) def tx_callback(tx): simulation = self._get_simulation_row( @@ -514,14 +554,16 @@ def tx_callback(tx): if simulation is None: raise ValueError(f"Simulation #{simulation_id} not found") - tx.query( - f"UPDATE simulations SET {', '.join(update_fields)} WHERE id = ? AND country_id = ?", - (*update_values, simulation_id, country_id), - ) + if update_fields: + tx.query( + f"UPDATE simulations SET {', '.join(update_fields)} WHERE id = ? AND country_id = ?", + (*update_values, simulation_id, country_id), + ) self._ensure_simulation_dual_write_state_in_transaction( tx, simulation_id, country_id=country_id, + version_manifest_overrides=version_manifest_overrides, ) database.transaction(tx_callback) diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index 55ee2ff62..e0ae0a827 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -1731,7 +1731,7 @@ def test_create_report_output_rolls_back_parent_insert_on_dual_write_failure( policy_id=34, ) - def fail_dual_write(tx, report_output_id, *, country_id=None): + def fail_dual_write(tx, report_output_id, *, country_id=None, **kwargs): raise RuntimeError("dual write sync failed") monkeypatch.setattr( @@ -1773,7 +1773,7 @@ def test_update_report_output_rolls_back_parent_update_on_dual_write_failure( year="2025", ) - def fail_dual_write(tx, report_output_id, *, country_id=None): + def fail_dual_write(tx, report_output_id, *, country_id=None, **kwargs): raise RuntimeError("dual write sync failed") monkeypatch.setattr( diff --git a/tests/unit/services/test_simulation_service.py b/tests/unit/services/test_simulation_service.py index 34116287f..8c442df03 100644 --- a/tests/unit/services/test_simulation_service.py +++ b/tests/unit/services/test_simulation_service.py @@ -227,7 +227,7 @@ def test_create_simulation_reuses_existing_row_and_bootstraps_dual_write( def test_create_simulation_rolls_back_parent_insert_on_dual_write_failure( self, test_db, monkeypatch ): - def fail_dual_write(tx, simulation_id, *, country_id=None): + def fail_dual_write(tx, simulation_id, *, country_id=None, **kwargs): raise RuntimeError("dual write sync failed") monkeypatch.setattr( @@ -454,7 +454,7 @@ def test_update_simulation_rolls_back_parent_update_on_dual_write_failure( policy_id=15, ) - def fail_dual_write(tx, simulation_id, *, country_id=None): + def fail_dual_write(tx, simulation_id, *, country_id=None, **kwargs): raise RuntimeError("dual write sync failed") monkeypatch.setattr( diff --git a/tests/unit/test_stage5_routes.py b/tests/unit/test_stage5_routes.py index ec9f34e1b..35ea809f8 100644 --- a/tests/unit/test_stage5_routes.py +++ b/tests/unit/test_stage5_routes.py @@ -167,6 +167,73 @@ def test_post_report_output_returns_timestamp_fields_for_new_and_existing_report assert existing_report["finished_at"] is None +def test_create_report_output_with_explicit_spec_persists_it(test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ny", + population_type="geography", + policy_id=45, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ny", + population_type="geography", + policy_id=46, + ) + + client = create_test_client() + response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/ny", + "baseline_policy_id": 45, + "reform_policy_id": 46, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + }, + ) + + assert response.status_code == 201 + report_id = response.get_json()["result"]["id"] + + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report_id,), + ).fetchone() + assert stored_report["report_kind"] == "economy_comparison" + assert stored_report["report_spec_schema_version"] == 1 + assert stored_report["report_spec_status"] == "explicit" + + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report_id,), + ).fetchone() + assert run is not None + snapshot = run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert snapshot["options"] == {"view": "tax"} + + def test_create_report_output_missing_primary_simulation_returns_bad_request(test_db): client = create_test_client() response = client.post( @@ -258,6 +325,39 @@ def test_patch_simulation_wrong_country_returns_not_found_and_does_not_mutate(te assert stored_simulation["output"] is None +def test_patch_simulation_persists_run_metadata_fields(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_metadata", + population_type="household", + policy_id=47, + ) + + client = create_test_client() + response = client.patch( + "/us/simulation", + json={ + "id": simulation["id"], + "status": "complete", + "output": json.dumps({"ok": True}), + "country_package_version": "1.620.0", + "policyengine_version": "0.94.2", + "data_version": "2026.04.16", + "runtime_app_name": "policyengine-app-v2", + }, + ) + + assert response.status_code == 200 + run = test_db.query( + "SELECT * FROM simulation_runs WHERE simulation_id = ?", + (simulation["id"],), + ).fetchone() + assert run["country_package_version"] == "1.620.0" + assert run["policyengine_version"] == "0.94.2" + assert run["data_version"] == "2026.04.16" + assert run["runtime_app_name"] == "policyengine-app-v2" + + def test_get_report_output_wrong_country_returns_not_found(test_db): test_db.query( """ @@ -580,3 +680,44 @@ def test_patch_report_output_complete_promotes_active_rerun_route_path(test_db): ).fetchone() assert stored_report["active_run_id"] is None assert stored_report["latest_successful_run_id"] == rerun["id"] + + +def test_patch_report_output_persists_run_metadata_fields(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/wa", + population_type="geography", + policy_id=48, + ) + report_output = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": report_output["id"], + "status": "complete", + "output": json.dumps({"result": "ok"}), + "country_package_version": "1.621.0", + "policyengine_version": "0.95.0", + "data_version": "2026.04.17", + "runtime_app_name": "policyengine-app-v2", + "resolved_dataset": "enhanced_us_household", + }, + ) + + assert response.status_code == 200 + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report_output["id"],), + ).fetchone() + assert run["country_package_version"] == "1.621.0" + assert run["policyengine_version"] == "0.95.0" + assert run["data_version"] == "2026.04.17" + assert run["runtime_app_name"] == "policyengine-app-v2" + assert run["resolved_dataset"] == "enhanced_us_household" From 16d843e21b8a5556e3f95d0453870a6d24a2d2ba Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 18 Apr 2026 00:59:21 +0200 Subject: [PATCH 02/17] Preserve explicit stage 6 report specs --- .../services/report_output_service.py | 170 ++++++++++---- .../services/test_report_output_service.py | 209 ++++++++++++++++++ tests/unit/test_stage5_routes.py | 146 ++++++++++++ 3 files changed, 477 insertions(+), 48 deletions(-) diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index 346a6eeeb..4407840ae 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -312,65 +312,94 @@ def _get_report_spec_status(self, report_spec: ReportSpec) -> str: return "backfilled_assumed" return "explicit" - def _upsert_report_spec_in_transaction( + def _persist_explicit_report_spec_in_transaction( self, tx, report_output: dict, - simulation_1: dict | None, + simulation_1: dict, simulation_2: dict | None, - explicit_report_spec: ReportSpec | None = None, + explicit_report_spec: ReportSpec, report_spec_schema_version: int | None = None, + ) -> ReportSpec: + schema_version = ( + report_spec_schema_version + if report_spec_schema_version is not None + else REPORT_SPEC_SCHEMA_VERSION + ) + self.report_spec_service._validate_schema_version(schema_version) + self.report_spec_service.validate_report_spec_matches_context( + report_output, + explicit_report_spec, + simulation_1, + simulation_2, + ) + report_spec_status = "explicit" + existing_spec = parse_json_field(report_output.get("report_spec_json")) + if ( + existing_spec != explicit_report_spec.model_dump() + or report_output.get("report_kind") != explicit_report_spec.report_kind + or report_output.get("report_spec_schema_version") != schema_version + or report_output.get("report_spec_status") != report_spec_status + ): + tx.query( + """ + UPDATE report_outputs + SET report_kind = ?, report_spec_json = ?, + report_spec_schema_version = ?, report_spec_status = ? + WHERE id = ? + """, + ( + explicit_report_spec.report_kind, + explicit_report_spec.model_dump_json(), + schema_version, + report_spec_status, + report_output["id"], + ), + ) + report_output["report_kind"] = explicit_report_spec.report_kind + report_output["report_spec_json"] = explicit_report_spec.model_dump() + report_output["report_spec_schema_version"] = schema_version + report_output["report_spec_status"] = report_spec_status + return explicit_report_spec + + def _load_existing_explicit_report_spec( + self, + report_output: dict, + simulation_1: dict, + simulation_2: dict | None, ) -> ReportSpec | None: - if simulation_1 is None: - if explicit_report_spec is not None: - raise ValueError( - "Explicit report specs require linked simulations to be present" - ) + if report_output.get("report_spec_status") != "explicit": return None - if explicit_report_spec is not None: - schema_version = ( - report_spec_schema_version - if report_spec_schema_version is not None - else REPORT_SPEC_SCHEMA_VERSION + raw_spec = parse_json_field(report_output.get("report_spec_json")) + if raw_spec is None: + raise ValueError( + "Stored explicit report spec is missing report_spec_json" ) - self.report_spec_service._validate_schema_version(schema_version) - self.report_spec_service.validate_report_spec_matches_context( - report_output, - explicit_report_spec, - simulation_1, - simulation_2, + + report_spec = self.report_spec_service.parse_report_spec( + raw_spec, + schema_version=report_output.get("report_spec_schema_version"), + ) + if report_output.get("report_kind") != report_spec.report_kind: + raise ValueError( + "Stored explicit report kind must match stored report spec" ) - report_spec_status = "explicit" - existing_spec = parse_json_field(report_output.get("report_spec_json")) - if ( - existing_spec != explicit_report_spec.model_dump() - or report_output.get("report_kind") - != explicit_report_spec.report_kind - or report_output.get("report_spec_schema_version") != schema_version - or report_output.get("report_spec_status") != report_spec_status - ): - tx.query( - """ - UPDATE report_outputs - SET report_kind = ?, report_spec_json = ?, - report_spec_schema_version = ?, report_spec_status = ? - WHERE id = ? - """, - ( - explicit_report_spec.report_kind, - explicit_report_spec.model_dump_json(), - schema_version, - report_spec_status, - report_output["id"], - ), - ) - report_output["report_kind"] = explicit_report_spec.report_kind - report_output["report_spec_json"] = explicit_report_spec.model_dump() - report_output["report_spec_schema_version"] = schema_version - report_output["report_spec_status"] = report_spec_status - return explicit_report_spec + self.report_spec_service.validate_report_spec_matches_context( + report_output, + report_spec, + simulation_1, + simulation_2, + ) + return report_spec + def _derive_and_upsert_report_spec_in_transaction( + self, + tx, + report_output: dict, + simulation_1: dict, + simulation_2: dict | None, + ) -> ReportSpec | None: try: report_spec = self.report_spec_service.build_report_spec( report_output=report_output, @@ -414,6 +443,51 @@ def _upsert_report_spec_in_transaction( return report_spec + def _upsert_report_spec_in_transaction( + self, + tx, + report_output: dict, + simulation_1: dict | None, + simulation_2: dict | None, + explicit_report_spec: ReportSpec | None = None, + report_spec_schema_version: int | None = None, + ) -> ReportSpec | None: + if simulation_1 is None: + if explicit_report_spec is not None: + raise ValueError( + "Explicit report specs require linked simulations to be present" + ) + if report_output.get("report_spec_status") == "explicit": + raise ValueError( + "Stored explicit report specs require linked simulations to be present" + ) + return None + + if explicit_report_spec is not None: + return self._persist_explicit_report_spec_in_transaction( + tx, + report_output, + simulation_1, + simulation_2, + explicit_report_spec, + report_spec_schema_version=report_spec_schema_version, + ) + + stored_explicit_report_spec = self._load_existing_explicit_report_spec( + report_output, + simulation_1, + simulation_2, + ) + if stored_explicit_report_spec is not None: + return stored_explicit_report_spec + + return self._derive_and_upsert_report_spec_in_transaction( + tx, + report_output, + simulation_1, + simulation_2, + ) + def _run_matches_parent( self, run: dict, diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index e0ae0a827..0200cf784 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -1465,6 +1465,72 @@ def test_update_report_output_updates_dual_write_state(self, test_db): assert run["output"] == output_json assert run["id"] == stored_report["latest_successful_run_id"] + def test_update_report_output_preserves_stored_explicit_report_spec(self, test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/co", + population_type="geography", + policy_id=61, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/co", + population_type="geography", + policy_id=62, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/co", + "baseline_policy_id": 61, + "reform_policy_id": 62, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + created_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + success = service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="complete", + output=json.dumps({"result": "ok"}), + ) + + assert success is True + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (created_report["id"],), + ).fetchone() + assert stored_report["report_spec_status"] == "explicit" + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + snapshot = run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert snapshot["options"] == {"view": "tax"} + def test_update_report_output_bootstraps_missing_run_state(self, test_db): simulation_1 = simulation_service.create_simulation( country_id="us", @@ -1895,3 +1961,146 @@ def test_ensure_report_output_dual_write_state_bootstraps_linked_simulations( ).fetchone() assert simulation_1_run is not None assert simulation_2_run is not None + + def test_ensure_report_output_dual_write_state_reuses_stored_explicit_report_spec( + self, test_db + ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/il", + population_type="geography", + policy_id=63, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/il", + population_type="geography", + policy_id=64, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/il", + "baseline_policy_id": 63, + "reform_policy_id": 64, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + created_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + synced_report = service.ensure_report_output_dual_write_state( + created_report["id"], + country_id="us", + ) + + assert synced_report["report_spec_status"] == "explicit" + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (created_report["id"],), + ).fetchone() + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + snapshot = run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert snapshot["options"] == {"view": "tax"} + + def test_update_report_output_invalid_stored_explicit_report_spec_fails_closed( + self, test_db + ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/mi", + population_type="geography", + policy_id=65, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/mi", + population_type="geography", + policy_id=66, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/mi", + "baseline_policy_id": 65, + "reform_policy_id": 66, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + created_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + corrupted_spec = { + **explicit_report_spec.model_dump(), + "region": "state/ca", + } + test_db.query( + """ + UPDATE report_outputs + SET report_spec_json = ? + WHERE id = ? + """, + ( + json.dumps(corrupted_spec), + created_report["id"], + ), + ) + + with pytest.raises( + ValueError, match="Report spec region must match linked simulations" + ): + service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="complete", + output=json.dumps({"result": "should_rollback"}), + ) + + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (created_report["id"],), + ).fetchone() + assert stored_report["status"] == "pending" + assert stored_report["output"] is None + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + assert run is not None + assert run["status"] == "pending" + assert run["output"] is None diff --git a/tests/unit/test_stage5_routes.py b/tests/unit/test_stage5_routes.py index 35ea809f8..407266b94 100644 --- a/tests/unit/test_stage5_routes.py +++ b/tests/unit/test_stage5_routes.py @@ -721,3 +721,149 @@ def test_patch_report_output_persists_run_metadata_fields(test_db): assert run["data_version"] == "2026.04.17" assert run["runtime_app_name"] == "policyengine-app-v2" assert run["resolved_dataset"] == "enhanced_us_household" + + +def test_patch_report_output_preserves_stored_explicit_report_spec(test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/or", + population_type="geography", + policy_id=49, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/or", + population_type="geography", + policy_id=50, + ) + + client = create_test_client() + create_response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/or", + "baseline_policy_id": 49, + "reform_policy_id": 50, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + }, + ) + report_id = create_response.get_json()["result"]["id"] + + patch_response = client.patch( + "/us/report", + json={ + "id": report_id, + "status": "complete", + "output": json.dumps({"result": "ok"}), + }, + ) + + assert patch_response.status_code == 200 + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report_id,), + ).fetchone() + assert stored_report["report_spec_status"] == "explicit" + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report_id,), + ).fetchone() + assert run is not None + snapshot = run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert snapshot["options"] == {"view": "tax"} + + +def test_patch_report_output_metadata_only_preserves_stored_explicit_report_spec(test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/nj", + population_type="geography", + policy_id=51, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/nj", + population_type="geography", + policy_id=52, + ) + + client = create_test_client() + create_response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/nj", + "baseline_policy_id": 51, + "reform_policy_id": 52, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + }, + ) + report_id = create_response.get_json()["result"]["id"] + + patch_response = client.patch( + "/us/report", + json={ + "id": report_id, + "policyengine_version": "0.95.1", + "runtime_app_name": "policyengine-app-v2", + }, + ) + + assert patch_response.status_code == 200 + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report_id,), + ).fetchone() + assert stored_report["report_spec_status"] == "explicit" + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report_id,), + ).fetchone() + assert run is not None + assert run["policyengine_version"] == "0.95.1" + assert run["runtime_app_name"] == "policyengine-app-v2" + snapshot = run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert snapshot["options"] == {"view": "tax"} From b0340c9f03f8cbd1d60e9ba262f56de76bc0a0d1 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 18 Apr 2026 02:58:10 +0200 Subject: [PATCH 03/17] Add report identity hash for report outputs --- policyengine_api/data/initialise.sql | 2 + policyengine_api/data/initialise_local.sql | 2 + .../services/report_spec_service.py | 65 +++++++++ tests/unit/data/test_run_schema.py | 4 + .../unit/services/test_report_spec_service.py | 133 ++++++++++++++++++ 5 files changed, 206 insertions(+) diff --git a/policyengine_api/data/initialise.sql b/policyengine_api/data/initialise.sql index 085f31c0b..4aacc9112 100644 --- a/policyengine_api/data/initialise.sql +++ b/policyengine_api/data/initialise.sql @@ -135,6 +135,8 @@ CREATE TABLE IF NOT EXISTS report_outputs ( report_spec_json JSON DEFAULT NULL, report_spec_schema_version INT DEFAULT NULL, report_spec_status VARCHAR(32) DEFAULT NULL, + report_identity_hash VARCHAR(64) DEFAULT NULL, + report_identity_schema_version INT DEFAULT NULL, active_run_id CHAR(36) DEFAULT NULL, latest_successful_run_id CHAR(36) DEFAULT NULL ); diff --git a/policyengine_api/data/initialise_local.sql b/policyengine_api/data/initialise_local.sql index 53a37b4c8..b8530be65 100644 --- a/policyengine_api/data/initialise_local.sql +++ b/policyengine_api/data/initialise_local.sql @@ -147,6 +147,8 @@ CREATE TABLE IF NOT EXISTS report_outputs ( report_spec_json JSON DEFAULT NULL, report_spec_schema_version INT DEFAULT NULL, report_spec_status VARCHAR(32) DEFAULT NULL, + report_identity_hash VARCHAR(64) DEFAULT NULL, + report_identity_schema_version INT DEFAULT NULL, active_run_id CHAR(36) DEFAULT NULL, latest_successful_run_id CHAR(36) DEFAULT NULL ); diff --git a/policyengine_api/services/report_spec_service.py b/policyengine_api/services/report_spec_service.py index aa77e5fa2..457d0dc0e 100644 --- a/policyengine_api/services/report_spec_service.py +++ b/policyengine_api/services/report_spec_service.py @@ -1,12 +1,15 @@ import json +import hashlib from typing import Any, Literal from pydantic import BaseModel, Field from sqlalchemy.engine.row import Row from policyengine_api.data import database +from policyengine_api.data.congressional_districts import normalize_us_region REPORT_SPEC_SCHEMA_VERSION = 1 +REPORT_IDENTITY_SCHEMA_VERSION = 1 REPORT_SPEC_STATUSES = {"explicit", "backfilled_assumed"} HOUSEHOLD_REPORT_KINDS = {"household_single", "household_comparison"} ECONOMY_REPORT_KINDS = {"economy_single", "economy_comparison"} @@ -48,6 +51,14 @@ def _validate_schema_version(self, schema_version: int | None) -> None: f"Unsupported report spec schema version: {schema_version}" ) + def _validate_report_identity_schema_version( + self, schema_version: int | None + ) -> None: + if schema_version != REPORT_IDENTITY_SCHEMA_VERSION: + raise ValueError( + f"Unsupported report identity schema version: {schema_version}" + ) + def _get_report_output_row(self, report_output_id: int) -> dict | None: row: Row | None = database.query( "SELECT * FROM report_outputs WHERE id = ?", @@ -364,6 +375,60 @@ def _parse_json_field(self, value: str | dict | None) -> dict | None: return json.loads(value) return value + def canonicalize_report_spec_for_identity( + self, + report_spec: ReportSpec, + schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, + ) -> dict[str, Any]: + self._validate_report_identity_schema_version(schema_version) + + canonical_spec = report_spec.model_dump() + if ( + isinstance(report_spec, EconomyReportSpec) + and report_spec.country_id == "us" + ): + canonical_spec["region"] = normalize_us_region(canonical_spec["region"]) + return canonical_spec + + def serialize_canonical_report_spec_for_identity( + self, + report_spec: ReportSpec, + schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, + ) -> str: + canonical_spec = self.canonicalize_report_spec_for_identity( + report_spec, + schema_version=schema_version, + ) + return json.dumps( + canonical_spec, + sort_keys=True, + separators=(",", ":"), + ) + + def get_report_identity_hash( + self, + report_spec: ReportSpec, + schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, + ) -> str: + canonical_json = self.serialize_canonical_report_spec_for_identity( + report_spec, + schema_version=schema_version, + ) + return hashlib.sha256(canonical_json.encode("utf-8")).hexdigest() + + def get_report_identity( + self, + report_spec: ReportSpec, + schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, + ) -> tuple[str, int]: + return ( + self.get_report_identity_hash( + report_spec, + schema_version=schema_version, + ), + schema_version, + ) + def _parse_report_spec(self, report_kind: str, raw_spec: dict) -> ReportSpec: if report_kind in HOUSEHOLD_REPORT_KINDS: return HouseholdReportSpec.model_validate(raw_spec) diff --git a/tests/unit/data/test_run_schema.py b/tests/unit/data/test_run_schema.py index 2bcba1eff..72b0ba6c0 100644 --- a/tests/unit/data/test_run_schema.py +++ b/tests/unit/data/test_run_schema.py @@ -15,6 +15,8 @@ def test_stage_one_run_schema_is_initialized_in_local_test_db(test_db): "report_spec_json", "report_spec_schema_version", "report_spec_status", + "report_identity_hash", + "report_identity_schema_version", "active_run_id", "latest_successful_run_id", }.issubset(report_output_columns) @@ -77,6 +79,8 @@ def test_stage_one_schema_is_defined_in_both_sql_initializers(): "CREATE TABLE IF NOT EXISTS legacy_report_output_aliases", "report_spec_json", "report_spec_status", + "report_identity_hash", + "report_identity_schema_version", "simulation_spec_json", "active_run_id", "latest_successful_run_id", diff --git a/tests/unit/services/test_report_spec_service.py b/tests/unit/services/test_report_spec_service.py index f924df8db..0dd98db86 100644 --- a/tests/unit/services/test_report_spec_service.py +++ b/tests/unit/services/test_report_spec_service.py @@ -5,6 +5,7 @@ from policyengine_api.services.report_spec_service import ( EconomyReportSpec, HouseholdReportSpec, + REPORT_IDENTITY_SCHEMA_VERSION, ReportSpecService, ) from policyengine_api.services.simulation_service import SimulationService @@ -467,3 +468,135 @@ def test_rejects_unsupported_schema_version_on_read(self, test_db): report_spec_service.get_report_spec(report_output["id"]) assert "Unsupported report spec schema version" in str(exc_info.value) + + +class TestReportIdentity: + def test_canonical_identity_reuses_normalized_us_region(self): + report_spec = EconomyReportSpec.model_validate( + { + "country_id": "us", + "report_kind": "economy_single", + "time_period": "2027", + "region": "ca", + "baseline_policy_id": 10, + "reform_policy_id": 10, + "dataset": "default", + "target": "general", + "options": {}, + } + ) + + canonical_spec = report_spec_service.canonicalize_report_spec_for_identity( + report_spec + ) + + assert canonical_spec["region"] == "state/ca" + + def test_equal_specs_produce_equal_hashes_despite_json_key_order(self): + first_spec = EconomyReportSpec.model_validate( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "region": "state/ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "default", + "target": "general", + "options": {"b": 2, "a": 1}, + } + ) + second_spec = EconomyReportSpec.model_validate( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "region": "ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "default", + "target": "general", + "options": {"a": 1, "b": 2}, + } + ) + + assert report_spec_service.get_report_identity_hash( + first_spec + ) == report_spec_service.get_report_identity_hash(second_spec) + + def test_distinct_economy_dataset_changes_identity_hash(self): + first_spec = EconomyReportSpec.model_validate( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "region": "state/ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "default", + "target": "general", + "options": {}, + } + ) + second_spec = EconomyReportSpec.model_validate( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "region": "state/ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "enhanced_us_household", + "target": "general", + "options": {}, + } + ) + + assert report_spec_service.get_report_identity_hash( + first_spec + ) != report_spec_service.get_report_identity_hash(second_spec) + + def test_report_identity_returns_hash_and_schema_version(self): + report_spec = HouseholdReportSpec.model_validate( + { + "country_id": "uk", + "report_kind": "household_single", + "time_period": "2027", + "simulation_1": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 1, + }, + "simulation_2": None, + } + ) + + report_identity_hash, schema_version = report_spec_service.get_report_identity( + report_spec + ) + + assert len(report_identity_hash) == 64 + assert schema_version == REPORT_IDENTITY_SCHEMA_VERSION + + def test_rejects_unsupported_identity_schema_version(self): + report_spec = HouseholdReportSpec.model_validate( + { + "country_id": "uk", + "report_kind": "household_single", + "time_period": "2027", + "simulation_1": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 1, + }, + "simulation_2": None, + } + ) + + with pytest.raises(ValueError) as exc_info: + report_spec_service.get_report_identity_hash( + report_spec, + schema_version=2, + ) + + assert "Unsupported report identity schema version" in str(exc_info.value) From 722a90258dd5c7a3033e29b3bfd53e4b3538d7fa Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 22 Apr 2026 22:17:09 +0200 Subject: [PATCH 04/17] Create reports by canonical report identity --- .../routes/report_output_routes.py | 3 +- .../services/report_output_service.py | 342 +++++++++++++++++- .../services/test_report_output_service.py | 142 +++++++- tests/unit/test_stage5_routes.py | 113 +++++- 4 files changed, 576 insertions(+), 24 deletions(-) diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index 03430c53f..475b016c1 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -90,11 +90,12 @@ def create_report_output(country_id: str) -> Response: try: # Check if report already exists with these simulation IDs and year - existing_report = report_output_service.find_existing_report_output( + existing_report = report_output_service.find_existing_report_output_for_create( country_id=country_id, simulation_1_id=simulation_1_id, simulation_2_id=simulation_2_id, year=year, + report_spec=report_spec, ) if existing_report: diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index 4407840ae..9c17105d0 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -362,6 +362,40 @@ def _persist_explicit_report_spec_in_transaction( report_output["report_spec_status"] = report_spec_status return explicit_report_spec + def _sync_report_identity_in_transaction( + self, + tx, + report_output: dict, + report_spec: ReportSpec | None, + ) -> None: + if report_spec is None: + return + + report_identity_hash, report_identity_schema_version = ( + self.report_spec_service.get_report_identity(report_spec) + ) + if ( + report_output.get("report_identity_hash") == report_identity_hash + and report_output.get("report_identity_schema_version") + == report_identity_schema_version + ): + return + + tx.query( + """ + UPDATE report_outputs + SET report_identity_hash = ?, report_identity_schema_version = ? + WHERE id = ? + """, + ( + report_identity_hash, + report_identity_schema_version, + report_output["id"], + ), + ) + report_output["report_identity_hash"] = report_identity_hash + report_output["report_identity_schema_version"] = report_identity_schema_version + def _load_existing_explicit_report_spec( self, report_output: dict, @@ -373,9 +407,7 @@ def _load_existing_explicit_report_spec( raw_spec = parse_json_field(report_output.get("report_spec_json")) if raw_spec is None: - raise ValueError( - "Stored explicit report spec is missing report_spec_json" - ) + raise ValueError("Stored explicit report spec is missing report_spec_json") report_spec = self.report_spec_service.parse_report_spec( raw_spec, @@ -710,6 +742,7 @@ def _ensure_report_output_dual_write_state_in_transaction( explicit_report_spec=explicit_report_spec, report_spec_schema_version=report_spec_schema_version, ) + self._sync_report_identity_in_transaction(tx, report_output, report_spec) version_manifest = self._build_version_manifest( report_output, report_spec=report_spec, @@ -858,6 +891,205 @@ def _find_existing_report_output_row( row = queryer.query(query, tuple(params)).fetchone() return dict(row) if row is not None else None + def _find_existing_report_output_row_by_identity( + self, + *, + country_id: str, + report_identity_hash: str, + report_identity_schema_version: int, + queryer=None, + ) -> dict | None: + queryer = queryer or database + row = queryer.query( + """ + SELECT * FROM report_outputs + WHERE country_id = ? AND report_identity_hash = ? + AND report_identity_schema_version = ? + ORDER BY id DESC + """, + ( + country_id, + report_identity_hash, + report_identity_schema_version, + ), + ).fetchone() + return dict(row) if row is not None else None + + def _list_report_output_rows_by_legacy_key( + self, + *, + country_id: str, + simulation_1_id: int, + simulation_2_id: int | None, + year: str, + queryer=None, + ) -> list[dict]: + queryer = queryer or database + query = """ + SELECT * FROM report_outputs + WHERE country_id = ? AND simulation_1_id = ? AND year = ? + """ + params: list[int | str] = [country_id, simulation_1_id, year] + if simulation_2_id is not None: + query += " AND simulation_2_id = ?" + params.append(simulation_2_id) + else: + query += " AND simulation_2_id IS NULL" + query += " ORDER BY id DESC" + + rows = queryer.query(query, tuple(params)).fetchall() + return [dict(row) for row in rows] + + def _build_report_spec_for_create( + self, + *, + country_id: str, + simulation_1_id: int, + simulation_2_id: int | None, + year: str, + queryer=None, + ) -> ReportSpec | None: + queryer = queryer or database + simulation_1 = self.simulation_service._get_simulation_row( + simulation_1_id, + queryer=queryer, + country_id=country_id, + ) + if simulation_1 is None: + return None + + simulation_2 = None + if simulation_2_id is not None: + simulation_2 = self.simulation_service._get_simulation_row( + simulation_2_id, + queryer=queryer, + country_id=country_id, + ) + if simulation_2 is None: + return None + + try: + return self.report_spec_service.build_report_spec( + report_output={ + "country_id": country_id, + "simulation_1_id": simulation_1_id, + "simulation_2_id": simulation_2_id, + "year": year, + }, + simulation_1=simulation_1, + simulation_2=simulation_2, + ) + except ValueError: + return None + + def _get_report_spec_for_identity_matching( + self, + report_output: dict, + *, + queryer=None, + ) -> ReportSpec | None: + queryer = queryer or database + try: + simulation_1, simulation_2 = self._get_linked_simulations( + report_output, + queryer=queryer, + ) + except ValueError: + return None + + raw_spec = parse_json_field(report_output.get("report_spec_json")) + if ( + raw_spec is not None + and report_output.get("report_spec_schema_version") is not None + ): + try: + report_spec = self.report_spec_service.parse_report_spec( + raw_spec, + schema_version=report_output["report_spec_schema_version"], + ) + self.report_spec_service.validate_report_spec_matches_context( + report_output, + report_spec, + simulation_1, + simulation_2, + ) + return report_spec + except ValueError: + return None + + try: + return self.report_spec_service.build_report_spec( + report_output=report_output, + simulation_1=simulation_1, + simulation_2=simulation_2, + ) + except ValueError: + return None + + def _find_existing_report_output_for_create( + self, + *, + country_id: str, + simulation_1_id: int, + simulation_2_id: int | None, + year: str, + report_spec: ReportSpec | None = None, + queryer=None, + ) -> dict | None: + queryer = queryer or database + identity_report_spec = report_spec or self._build_report_spec_for_create( + country_id=country_id, + simulation_1_id=simulation_1_id, + simulation_2_id=simulation_2_id, + year=year, + queryer=queryer, + ) + if identity_report_spec is None: + return self._find_existing_report_output_row( + country_id=country_id, + simulation_1_id=simulation_1_id, + simulation_2_id=simulation_2_id, + year=year, + queryer=queryer, + ) + + report_identity_hash, report_identity_schema_version = ( + self.report_spec_service.get_report_identity(identity_report_spec) + ) + existing_report = self._find_existing_report_output_row_by_identity( + country_id=country_id, + report_identity_hash=report_identity_hash, + report_identity_schema_version=report_identity_schema_version, + queryer=queryer, + ) + if existing_report is not None: + return existing_report + + candidate_rows = self._list_report_output_rows_by_legacy_key( + country_id=country_id, + simulation_1_id=simulation_1_id, + simulation_2_id=simulation_2_id, + year=year, + queryer=queryer, + ) + for candidate_row in candidate_rows: + candidate_report_spec = self._get_report_spec_for_identity_matching( + candidate_row, + queryer=queryer, + ) + if candidate_report_spec is None: + continue + candidate_identity_hash, candidate_identity_schema_version = ( + self.report_spec_service.get_report_identity(candidate_report_spec) + ) + if ( + candidate_identity_hash == report_identity_hash + and candidate_identity_schema_version == report_identity_schema_version + ): + return candidate_row + + return None + def _get_or_create_current_report_output(self, report_output: dict) -> dict: current_report = self.find_existing_report_output( country_id=report_output["country_id"], @@ -868,12 +1100,72 @@ def _get_or_create_current_report_output(self, report_output: dict) -> dict: if current_report is not None: return self._with_display_run_timestamps(current_report) - return self.create_report_output( - country_id=report_output["country_id"], - simulation_1_id=report_output["simulation_1_id"], - simulation_2_id=report_output["simulation_2_id"], - year=report_output["year"], - ) + api_version = get_report_output_cache_version(report_output["country_id"]) + + def tx_callback(tx): + existing_current_report = self._find_existing_report_output_row( + country_id=report_output["country_id"], + simulation_1_id=report_output["simulation_1_id"], + simulation_2_id=report_output["simulation_2_id"], + year=report_output["year"], + queryer=tx, + ) + if existing_current_report is not None: + return self._ensure_report_output_dual_write_state_in_transaction( + tx, + existing_current_report["id"], + country_id=report_output["country_id"], + ) + + if report_output["simulation_2_id"] is not None: + tx.query( + """ + INSERT INTO report_outputs ( + country_id, simulation_1_id, simulation_2_id, api_version, status, year + ) VALUES (?, ?, ?, ?, ?, ?) + """, + ( + report_output["country_id"], + report_output["simulation_1_id"], + report_output["simulation_2_id"], + api_version, + "pending", + report_output["year"], + ), + ) + else: + tx.query( + """ + INSERT INTO report_outputs ( + country_id, simulation_1_id, api_version, status, year + ) VALUES (?, ?, ?, ?, ?) + """, + ( + report_output["country_id"], + report_output["simulation_1_id"], + api_version, + "pending", + report_output["year"], + ), + ) + + created_current_report = self._find_existing_report_output_row( + country_id=report_output["country_id"], + simulation_1_id=report_output["simulation_1_id"], + simulation_2_id=report_output["simulation_2_id"], + year=report_output["year"], + queryer=tx, + ) + if created_current_report is None: + raise Exception("Failed to create current runtime report output") + + return self._ensure_report_output_dual_write_state_in_transaction( + tx, + created_current_report["id"], + country_id=report_output["country_id"], + ) + + return database.transaction(tx_callback) def _alias_report_output(self, report_output_id: int, report_output: dict) -> dict: aliased_report = dict(report_output) @@ -911,6 +1203,35 @@ def find_existing_report_output( print(f"Error checking for existing report output. Details: {str(e)}") raise e + def find_existing_report_output_for_create( + self, + country_id: str, + simulation_1_id: int, + simulation_2_id: int | None = None, + year: str = "2025", + report_spec: ReportSpec | None = None, + ) -> dict | None: + try: + existing_report = self._find_existing_report_output_for_create( + country_id=country_id, + simulation_1_id=simulation_1_id, + simulation_2_id=simulation_2_id, + year=year, + report_spec=report_spec, + ) + if existing_report is not None: + print( + "Found existing report output for create with ID: " + f"{existing_report['id']}" + ) + return existing_report + except Exception as e: + print( + "Error checking for existing report output by canonical identity. " + f"Details: {str(e)}" + ) + raise e + def create_report_output( self, country_id: str, @@ -929,11 +1250,12 @@ def create_report_output( try: def tx_callback(tx): - existing_report = self._find_existing_report_output_row( + existing_report = self._find_existing_report_output_for_create( country_id=country_id, simulation_1_id=simulation_1_id, simulation_2_id=simulation_2_id, year=year, + report_spec=report_spec, queryer=tx, ) if existing_report is not None: diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index 0200cf784..1466118f8 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -488,20 +488,138 @@ def test_create_report_output_populates_economy_comparison_report_spec( if isinstance(report_spec, str): report_spec = json.loads(report_spec) assert report_spec["region"] == "state/ca" - assert report_spec["baseline_policy_id"] == 30 - assert report_spec["reform_policy_id"] == 31 - assert report_spec["dataset"] == "default" - run = test_db.query( - "SELECT * FROM report_output_runs WHERE report_output_id = ?", - (created_report["id"],), + def test_create_report_output_reuses_same_explicit_economy_spec(self, test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ny", + population_type="geography", + policy_id=32, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ny", + population_type="geography", + policy_id=33, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/ny", + "baseline_policy_id": 32, + "reform_policy_id": 33, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + + first_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + second_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + assert first_report["id"] == second_report["id"] + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (first_report["id"],), ).fetchone() - assert run is not None - snapshot = run["report_spec_snapshot_json"] - if isinstance(snapshot, str): - snapshot = json.loads(snapshot) - assert snapshot["report_kind"] == "economy_comparison" - assert snapshot["region"] == "state/ca" + assert stored_report["report_identity_hash"] is not None + assert stored_report["report_identity_schema_version"] == 1 + + def test_create_report_output_distinguishes_explicit_economy_specs_by_identity( + self, test_db + ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/tx", + population_type="geography", + policy_id=34, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/tx", + population_type="geography", + policy_id=35, + ) + default_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/tx", + "baseline_policy_id": 34, + "reform_policy_id": 35, + "dataset": "default", + "target": "general", + "options": {}, + } + ) + cliff_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/tx", + "baseline_policy_id": 34, + "reform_policy_id": 35, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + + first_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=default_report_spec, + report_spec_schema_version=1, + ) + second_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=cliff_report_spec, + report_spec_schema_version=1, + ) + + assert first_report["id"] != second_report["id"] + stored_reports = test_db.query( + """ + SELECT id, report_identity_hash, report_spec_json + FROM report_outputs + WHERE country_id = ? AND simulation_1_id = ? AND simulation_2_id = ? AND year = ? + ORDER BY id + """, + ( + "us", + baseline_simulation["id"], + reform_simulation["id"], + "2026", + ), + ).fetchall() + assert len(stored_reports) == 2 + assert ( + stored_reports[0]["report_identity_hash"] + != stored_reports[1]["report_identity_hash"] + ) class TestGetReportOutput: diff --git a/tests/unit/test_stage5_routes.py b/tests/unit/test_stage5_routes.py index 407266b94..d06e71c4b 100644 --- a/tests/unit/test_stage5_routes.py +++ b/tests/unit/test_stage5_routes.py @@ -220,6 +220,8 @@ def test_create_report_output_with_explicit_spec_persists_it(test_db): assert report_spec["dataset"] == "enhanced_us_household" assert report_spec["target"] == "cliff" assert report_spec["options"] == {"view": "tax"} + assert stored_report["report_identity_hash"] is not None + assert stored_report["report_identity_schema_version"] == 1 run = test_db.query( "SELECT * FROM report_output_runs WHERE report_output_id = ?", @@ -234,6 +236,113 @@ def test_create_report_output_with_explicit_spec_persists_it(test_db): assert snapshot["options"] == {"view": "tax"} +def test_create_report_output_same_explicit_spec_returns_existing_row(test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/va", + population_type="geography", + policy_id=53, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/va", + population_type="geography", + policy_id=54, + ) + payload = { + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/va", + "baseline_policy_id": 53, + "reform_policy_id": 54, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + } + + client = create_test_client() + first_response = client.post("/us/report", json=payload) + second_response = client.post("/us/report", json=payload) + + assert first_response.status_code == 201 + assert second_response.status_code == 200 + assert ( + first_response.get_json()["result"]["id"] + == second_response.get_json()["result"]["id"] + ) + + +def test_create_report_output_distinct_explicit_specs_create_distinct_rows(test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/md", + population_type="geography", + policy_id=55, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/md", + population_type="geography", + policy_id=56, + ) + + client = create_test_client() + default_response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/md", + "baseline_policy_id": 55, + "reform_policy_id": 56, + "dataset": "default", + "target": "general", + "options": {}, + }, + }, + ) + cliff_response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/md", + "baseline_policy_id": 55, + "reform_policy_id": 56, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + }, + ) + + assert default_response.status_code == 201 + assert cliff_response.status_code == 201 + assert ( + default_response.get_json()["result"]["id"] + != cliff_response.get_json()["result"]["id"] + ) + + def test_create_report_output_missing_primary_simulation_returns_bad_request(test_db): client = create_test_client() response = client.post( @@ -795,7 +904,9 @@ def test_patch_report_output_preserves_stored_explicit_report_spec(test_db): assert snapshot["options"] == {"view": "tax"} -def test_patch_report_output_metadata_only_preserves_stored_explicit_report_spec(test_db): +def test_patch_report_output_metadata_only_preserves_stored_explicit_report_spec( + test_db, +): baseline_simulation = simulation_service.create_simulation( country_id="us", population_id="state/nj", From 8de9f1f4f1dad254b288a22881a7f39ccc1f1e0f Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 22 Apr 2026 22:56:44 +0200 Subject: [PATCH 05/17] Resolve report reads through aliases and canonical parents --- .../services/report_output_service.py | 151 +++++++----------- .../services/test_report_output_service.py | 141 ++++++++++------ tests/unit/test_stage5_routes.py | 56 +++++++ 3 files changed, 206 insertions(+), 142 deletions(-) diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index 9c17105d0..38ace6cdf 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -5,6 +5,10 @@ from policyengine_api.constants import get_report_output_cache_version from policyengine_api.data import database +from policyengine_api.services.report_output_alias_service import ( + ReportOutputAliasService, +) +from policyengine_api.services.report_run_service import ReportRunService from policyengine_api.services.report_spec_service import ( ECONOMY_REPORT_KINDS, ReportSpec, @@ -25,6 +29,8 @@ class ReportOutputService: def __init__(self): self.report_spec_service = ReportSpecService() self.simulation_service = SimulationService() + self.report_output_alias_service = ReportOutputAliasService() + self.report_run_service = ReportRunService() def _lock_clause(self) -> str: return "" if database.local else " FOR UPDATE" @@ -860,11 +866,6 @@ def report_output_exists(self, country_id: str, report_output_id: int) -> bool: is not None ) - def _is_current_report_output(self, report_output: dict) -> bool: - return report_output.get("api_version") == get_report_output_cache_version( - report_output["country_id"] - ) - def _find_existing_report_output_row( self, *, @@ -1090,88 +1091,29 @@ def _find_existing_report_output_for_create( return None - def _get_or_create_current_report_output(self, report_output: dict) -> dict: - current_report = self.find_existing_report_output( - country_id=report_output["country_id"], - simulation_1_id=report_output["simulation_1_id"], - simulation_2_id=report_output["simulation_2_id"], - year=report_output["year"], - ) - if current_report is not None: - return self._with_display_run_timestamps(current_report) - - api_version = get_report_output_cache_version(report_output["country_id"]) - - def tx_callback(tx): - existing_current_report = self._find_existing_report_output_row( - country_id=report_output["country_id"], - simulation_1_id=report_output["simulation_1_id"], - simulation_2_id=report_output["simulation_2_id"], - year=report_output["year"], - queryer=tx, - ) - if existing_current_report is not None: - return self._ensure_report_output_dual_write_state_in_transaction( - tx, - existing_current_report["id"], - country_id=report_output["country_id"], - ) - - if report_output["simulation_2_id"] is not None: - tx.query( - """ - INSERT INTO report_outputs ( - country_id, simulation_1_id, simulation_2_id, api_version, status, year - ) VALUES (?, ?, ?, ?, ?, ?) - """, - ( - report_output["country_id"], - report_output["simulation_1_id"], - report_output["simulation_2_id"], - api_version, - "pending", - report_output["year"], - ), - ) - else: - tx.query( - """ - INSERT INTO report_outputs ( - country_id, simulation_1_id, api_version, status, year - ) VALUES (?, ?, ?, ?, ?) - """, - ( - report_output["country_id"], - report_output["simulation_1_id"], - api_version, - "pending", - report_output["year"], - ), - ) - - created_current_report = self._find_existing_report_output_row( - country_id=report_output["country_id"], - simulation_1_id=report_output["simulation_1_id"], - simulation_2_id=report_output["simulation_2_id"], - year=report_output["year"], - queryer=tx, - ) - if created_current_report is None: - raise Exception("Failed to create current runtime report output") - - return self._ensure_report_output_dual_write_state_in_transaction( - tx, - created_current_report["id"], - country_id=report_output["country_id"], - ) - - return database.transaction(tx_callback) - def _alias_report_output(self, report_output_id: int, report_output: dict) -> dict: aliased_report = dict(report_output) aliased_report["id"] = report_output_id return aliased_report + def _merge_display_run_into_report_output( + self, + report_output: dict, + display_run: dict | None, + ) -> dict: + if display_run is None: + return dict(report_output) + + result = dict(report_output) + result["status"] = display_run["status"] + result["output"] = display_run.get("output") + result["error_message"] = display_run.get("error_message") + if display_run.get("report_cache_version") is not None: + result["api_version"] = display_run["report_cache_version"] + for field in ("requested_at", "started_at", "finished_at"): + result[field] = self._format_run_timestamp(display_run.get(field)) + return result + def find_existing_report_output( self, country_id: str, @@ -1351,21 +1293,48 @@ def get_report_output(self, country_id: str, report_output_id: int) -> dict | No f"Invalid report output ID: {report_output_id}. Must be a positive integer." ) - report_output = self._get_report_output_row( - report_output_id, + canonical_report_output_id = ( + self.report_output_alias_service.resolve_canonical_report_output_id( + report_output_id + ) + ) + if canonical_report_output_id is None: + return None + + canonical_report_output = self._get_report_output_row( + canonical_report_output_id, country_id=country_id, ) - if report_output is None: + if canonical_report_output is None: return None - if self._is_current_report_output(report_output): - return self.ensure_report_output_dual_write_state( - report_output_id, + display_run = self.report_run_service.select_display_run( + canonical_report_output + ) + if display_run is None or ( + run_matches_report_result(display_run, canonical_report_output) + and self._run_needs_timestamp_sync( + display_run, + canonical_report_output["status"], + ) + ): + canonical_report_output = self.ensure_report_output_dual_write_state( + canonical_report_output_id, country_id=country_id, ) - - current_report = self._get_or_create_current_report_output(report_output) - return self._alias_report_output(report_output_id, current_report) + display_run = self.report_run_service.select_display_run( + canonical_report_output + ) + resolved_report_output = self._merge_display_run_into_report_output( + canonical_report_output, + display_run, + ) + if report_output_id != canonical_report_output_id: + return self._alias_report_output( + report_output_id, + resolved_report_output, + ) + return resolved_report_output except Exception as e: print( diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index 1466118f8..1e48c52b5 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -3,6 +3,9 @@ from datetime import datetime, timezone from policyengine_api.constants import get_report_output_cache_version +from policyengine_api.services.report_output_alias_service import ( + ReportOutputAliasService, +) from policyengine_api.services.report_output_service import ReportOutputService from policyengine_api.services.report_run_service import ReportRunService from policyengine_api.services.run_sync_utils import select_display_report_run @@ -15,6 +18,7 @@ service = ReportOutputService() report_run_service = ReportRunService() simulation_service = SimulationService() +alias_service = ReportOutputAliasService() class TestReportOutputRunTimestamps: @@ -1305,74 +1309,111 @@ def test_find_existing_report_output_backfills_missing_timestamps(self, test_db) assert result is not None assert result["requested_at"] is not None - def test_get_report_output_resolves_stale_id_to_current_runtime_row(self, test_db): - stale_output = { - "budget": {"budgetary_impact": 1}, - "congressional_district_impact": { - "districts": [ - { - "district": "AL-01", - "average_household_income_change": 120, - "relative_household_income_change": 0.01, - } - ] - }, - } + def test_get_report_output_uses_selected_display_run_for_canonical_parent( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_display_run", + population_type="household", + policy_id=5, + ) + report_output = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + service.update_report_output( + country_id="us", + report_id=report_output["id"], + status="complete", + output=json.dumps({"budget": {"budgetary_impact": 2}}), + ) test_db.query( - """INSERT INTO report_outputs - (country_id, simulation_1_id, simulation_2_id, status, output, api_version, year) - VALUES (?, ?, ?, ?, ?, ?, ?)""", + """ + UPDATE report_outputs + SET status = ?, output = ?, api_version = ? + WHERE id = ? + """, ( - "us", - 2, + "pending", None, - "complete", - json.dumps(stale_output), "r0stale1", - "2025", + report_output["id"], ), ) - stale_record = test_db.query( - "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" - ).fetchone() + result = service.get_report_output( + country_id="us", report_output_id=report_output["id"] + ) - current_version = get_report_output_cache_version("us") + assert result is not None + assert result["id"] == report_output["id"] + assert result["status"] == "complete" + assert result["output"] == json.dumps({"budget": {"budgetary_impact": 2}}) + assert result["api_version"] == get_report_output_cache_version("us") + + def test_get_report_output_resolves_alias_to_canonical_parent_and_display_run( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_alias_display_run", + population_type="household", + policy_id=6, + ) + canonical_report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + service.update_report_output( + country_id="us", + report_id=canonical_report["id"], + status="complete", + output=json.dumps({"budget": {"budgetary_impact": 3}}), + ) test_db.query( - """INSERT INTO report_outputs - (country_id, simulation_1_id, simulation_2_id, status, output, api_version, year) - VALUES (?, ?, ?, ?, ?, ?, ?)""", + """ + INSERT INTO report_outputs ( + id, country_id, simulation_1_id, simulation_2_id, status, output, api_version, year + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, ( + 999, "us", - 2, + simulation["id"], None, - "complete", - json.dumps({"budget": {"budgetary_impact": 2}}), - current_version, + "error", + json.dumps({"legacy": True}), + "r0legacy1", "2025", ), ) + alias_service.set_alias( + legacy_report_output_id=999, + canonical_report_output_id=canonical_report["id"], + ) - current_record = test_db.query( - "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" - ).fetchone() + result = service.get_report_output(country_id="us", report_output_id=999) - result = service.get_report_output( - country_id="us", report_output_id=stale_record["id"] - ) assert result is not None - assert result["id"] == stale_record["id"] - assert result["api_version"] == current_record["api_version"] - assert result["output"] == current_record["output"] + assert result["id"] == 999 + assert result["status"] == "complete" + assert result["output"] == json.dumps({"budget": {"budgetary_impact": 3}}) + assert result["api_version"] == get_report_output_cache_version("us") - def test_get_report_output_creates_current_runtime_row_for_stale_id(self, test_db): + def test_get_report_output_does_not_create_current_runtime_row_for_stale_id( + self, test_db + ): stale_version = "r0stale1" - current_version = get_report_output_cache_version("us") simulation = simulation_service.create_simulation( country_id="us", - population_id="household_stale_runtime_create", + population_id="household_stale_runtime_read", population_type="household", - policy_id=5, + policy_id=7, ) test_db.query( @@ -1392,17 +1433,15 @@ def test_get_report_output_creates_current_runtime_row_for_stale_id(self, test_d assert result is not None assert result["id"] == stale_record["id"] - assert result["api_version"] == current_version - assert result["status"] == "pending" + assert result["api_version"] == stale_version + assert result["status"] == "complete" assert result["output"] is None - current_rows = test_db.query( + rows = test_db.query( "SELECT * FROM report_outputs WHERE country_id = ? AND simulation_1_id = ? AND year = ? ORDER BY id ASC", ("us", simulation["id"], "2025"), ).fetchall() - assert len(current_rows) == 2 - assert current_rows[0]["api_version"] == stale_version - assert current_rows[1]["api_version"] == current_version + assert len(rows) == 1 def test_get_report_output_invalid_id(self, test_db): """Test that invalid report IDs are handled properly.""" diff --git a/tests/unit/test_stage5_routes.py b/tests/unit/test_stage5_routes.py index d06e71c4b..cc2631e55 100644 --- a/tests/unit/test_stage5_routes.py +++ b/tests/unit/test_stage5_routes.py @@ -486,6 +486,62 @@ def test_get_report_output_wrong_country_returns_not_found(test_db): assert response.status_code == 404 +def test_get_report_output_alias_resolves_to_canonical_display_run(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_alias", + population_type="household", + policy_id=57, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + report_output_service.update_report_output( + country_id="us", + report_id=canonical_report["id"], + status="complete", + output=json.dumps({"result": "canonical"}), + ) + test_db.query( + """ + INSERT INTO report_outputs ( + id, country_id, simulation_1_id, simulation_2_id, api_version, status, output, year + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + 2001, + "us", + simulation["id"], + None, + "r0legacy1", + "error", + json.dumps({"result": "legacy"}), + "2025", + ), + ) + test_db.query( + """ + INSERT INTO legacy_report_output_aliases ( + legacy_report_output_id, canonical_report_output_id + ) VALUES (?, ?) + """, + (2001, canonical_report["id"]), + ) + + client = create_test_client() + response = client.get("/us/report/2001") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["id"] == 2001 + assert payload["result"]["status"] == "complete" + assert payload["result"]["output"] == json.dumps({"result": "canonical"}) + assert payload["result"]["api_version"] == get_report_output_cache_version("us") + + def test_patch_report_output_wrong_country_returns_not_found_and_does_not_mutate( test_db, ): From d1977ce709d5958f3cb1063d20ef141894a7a58a Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 22 Apr 2026 23:26:45 +0200 Subject: [PATCH 06/17] Preserve execution metadata from existing runs --- .../services/report_output_service.py | 79 +++++++++++++--- .../services/simulation_service.py | 77 +++++++++++----- .../services/test_report_output_service.py | 91 +++++++++++++++++++ .../unit/services/test_simulation_service.py | 77 ++++++++++++++++ 4 files changed, 289 insertions(+), 35 deletions(-) diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index 38ace6cdf..9a7f9770d 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -283,7 +283,18 @@ def _derive_report_country_package_version( return versions[0] return None - def _build_version_manifest( + def _merge_version_manifest_overrides( + self, + version_manifest: dict[str, str | None], + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + merged_manifest = dict(version_manifest) + for key, value in (version_manifest_overrides or {}).items(): + if key in merged_manifest and value is not None: + merged_manifest[key] = value + return merged_manifest + + def _build_bootstrap_version_manifest( self, report_output: dict, report_spec: ReportSpec | None, @@ -308,10 +319,36 @@ def _build_version_manifest( "resolved_dataset": resolved_dataset, "resolved_options_hash": None, } - for key, value in (version_manifest_overrides or {}).items(): - if key in version_manifest and value is not None: - version_manifest[key] = value - return version_manifest + return self._merge_version_manifest_overrides( + version_manifest, + version_manifest_overrides=version_manifest_overrides, + ) + + def _build_existing_run_version_manifest( + self, + run: dict, + report_output: dict, + report_spec: ReportSpec | None, + simulation_1: dict | None = None, + simulation_2: dict | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + fallback_manifest = self._build_bootstrap_version_manifest( + report_output, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + ) + version_manifest = { + key: run.get(key) + if run.get(key) is not None + else fallback_manifest.get(key) + for key in fallback_manifest + } + return self._merge_version_manifest_overrides( + version_manifest, + version_manifest_overrides=version_manifest_overrides, + ) def _get_report_spec_status(self, report_spec: ReportSpec) -> str: if report_spec.report_kind in ECONOMY_REPORT_KINDS: @@ -749,17 +786,17 @@ def _ensure_report_output_dual_write_state_in_transaction( report_spec_schema_version=report_spec_schema_version, ) self._sync_report_identity_in_transaction(tx, report_output, report_spec) - version_manifest = self._build_version_manifest( - report_output, - report_spec=report_spec, - simulation_1=simulation_1, - simulation_2=simulation_2, - version_manifest_overrides=version_manifest_overrides, - ) runs_descending = self._list_report_runs_descending( report_output_id, queryer=tx ) if not runs_descending: + version_manifest = self._build_bootstrap_version_manifest( + report_output, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) self._insert_bootstrap_report_run( tx, report_output, @@ -771,6 +808,24 @@ def _ensure_report_output_dual_write_state_in_transaction( ) else: mutable_run = self._select_mutable_run(report_output, runs_descending) + version_manifest = ( + self._build_existing_run_version_manifest( + mutable_run, + report_output, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) + if mutable_run is not None + else self._build_bootstrap_version_manifest( + report_output, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) + ) if mutable_run is not None: run_matches_parent = self._run_matches_parent( mutable_run, diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index 97a7027fa..157c5e1bd 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -62,7 +62,18 @@ def _find_existing_simulation_row( ).fetchone() return dict(row) if row is not None else None - def _build_version_manifest( + def _merge_version_manifest_overrides( + self, + version_manifest: dict[str, str | None], + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + merged_manifest = dict(version_manifest) + for key, value in (version_manifest_overrides or {}).items(): + if key in merged_manifest and value is not None: + merged_manifest[key] = value + return merged_manifest + + def _build_bootstrap_version_manifest( self, simulation: dict, version_manifest_overrides: dict[str, str | None] | None = None, @@ -74,10 +85,27 @@ def _build_version_manifest( "runtime_app_name": None, "simulation_cache_version": None, } - for key, value in (version_manifest_overrides or {}).items(): - if key in version_manifest and value is not None: - version_manifest[key] = value - return version_manifest + return self._merge_version_manifest_overrides( + version_manifest, + version_manifest_overrides=version_manifest_overrides, + ) + + def _build_existing_run_version_manifest( + self, + run: dict, + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + version_manifest = { + "country_package_version": run.get("country_package_version"), + "policyengine_version": run.get("policyengine_version"), + "data_version": run.get("data_version"), + "runtime_app_name": run.get("runtime_app_name"), + "simulation_cache_version": run.get("simulation_cache_version"), + } + return self._merge_version_manifest_overrides( + version_manifest, + version_manifest_overrides=version_manifest_overrides, + ) def _list_simulation_runs_descending( self, simulation_id: int, *, queryer=None @@ -142,12 +170,8 @@ def _run_matches_parent( run: dict, simulation: dict, simulation_spec: SimulationSpec, - version_manifest_overrides: dict[str, str | None] | None = None, + version_manifest: dict[str, str | None], ) -> bool: - version_manifest = self._build_version_manifest( - simulation, - version_manifest_overrides=version_manifest_overrides, - ) return ( run["status"] == simulation["status"] and run.get("output") == simulation.get("output") @@ -168,12 +192,8 @@ def _insert_bootstrap_run( tx, simulation: dict, simulation_spec: SimulationSpec, - version_manifest_overrides: dict[str, str | None] | None = None, + version_manifest: dict[str, str | None], ) -> None: - version_manifest = self._build_version_manifest( - simulation, - version_manifest_overrides=version_manifest_overrides, - ) tx.query( """ INSERT INTO simulation_runs ( @@ -213,12 +233,8 @@ def _update_simulation_run_in_transaction( run_id: str, simulation: dict, simulation_spec: SimulationSpec, - version_manifest_overrides: dict[str, str | None] | None = None, + version_manifest: dict[str, str | None], ) -> None: - version_manifest = self._build_version_manifest( - simulation, - version_manifest_overrides=version_manifest_overrides, - ) tx.query( """ UPDATE simulation_runs @@ -292,29 +308,44 @@ def _ensure_simulation_dual_write_state_in_transaction( simulation_id, queryer=tx ) if not runs_descending: + version_manifest = self._build_bootstrap_version_manifest( + simulation, + version_manifest_overrides=version_manifest_overrides, + ) self._insert_bootstrap_run( tx, simulation, simulation_spec, - version_manifest_overrides=version_manifest_overrides, + version_manifest=version_manifest, ) runs_descending = self._list_simulation_runs_descending( simulation_id, queryer=tx ) else: mutable_run = self._select_mutable_run(simulation, runs_descending) + version_manifest = ( + self._build_existing_run_version_manifest( + mutable_run, + version_manifest_overrides=version_manifest_overrides, + ) + if mutable_run is not None + else self._build_bootstrap_version_manifest( + simulation, + version_manifest_overrides=version_manifest_overrides, + ) + ) if mutable_run is not None and not self._run_matches_parent( mutable_run, simulation, simulation_spec, - version_manifest_overrides=version_manifest_overrides, + version_manifest=version_manifest, ): self._update_simulation_run_in_transaction( tx, run_id=mutable_run["id"], simulation=simulation, simulation_spec=simulation_spec, - version_manifest_overrides=version_manifest_overrides, + version_manifest=version_manifest, ) runs_descending = self._list_simulation_runs_descending( simulation_id, queryer=tx diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index 1e48c52b5..5be4389d7 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -1688,6 +1688,97 @@ def test_update_report_output_preserves_stored_explicit_report_spec(self, test_d assert snapshot["target"] == "cliff" assert snapshot["options"] == {"view": "tax"} + def test_update_report_output_preserves_existing_run_metadata_without_overrides( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/az", + population_type="geography", + policy_id=63, + ) + created_report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + + service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="complete", + output=json.dumps({"result": "ok"}), + version_manifest_overrides={ + "country_package_version": "1.621.0", + "policyengine_version": "0.95.0", + "data_version": "2026.04.17", + "runtime_app_name": "policyengine-app-v2", + "resolved_dataset": "enhanced_us_household", + }, + ) + + service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="error", + error_message="later failure", + ) + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + assert run["status"] == "error" + assert run["error_message"] == "later failure" + assert run["country_package_version"] == "1.621.0" + assert run["policyengine_version"] == "0.95.0" + assert run["data_version"] == "2026.04.17" + assert run["runtime_app_name"] == "policyengine-app-v2" + assert run["resolved_dataset"] == "enhanced_us_household" + + def test_update_report_output_allows_explicit_metadata_override_on_existing_run( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/nm", + population_type="geography", + policy_id=64, + ) + created_report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + + service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="complete", + output=json.dumps({"result": "ok"}), + version_manifest_overrides={ + "country_package_version": "1.621.0", + "policyengine_version": "0.95.0", + }, + ) + + service.update_report_output( + country_id="us", + report_id=created_report["id"], + version_manifest_overrides={ + "policyengine_version": "0.95.1", + }, + ) + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + assert run["country_package_version"] == "1.621.0" + assert run["policyengine_version"] == "0.95.1" + def test_update_report_output_bootstraps_missing_run_state(self, test_db): simulation_1 = simulation_service.create_simulation( country_id="us", diff --git a/tests/unit/services/test_simulation_service.py b/tests/unit/services/test_simulation_service.py index 8c442df03..8117488f7 100644 --- a/tests/unit/services/test_simulation_service.py +++ b/tests/unit/services/test_simulation_service.py @@ -520,3 +520,80 @@ def test_update_simulation_with_no_user_fields_returns_false(self, test_db): ).fetchone() assert post_row["api_version"] == pre_row["api_version"] assert post_row["status"] == pre_row["status"] + + def test_update_simulation_preserves_existing_run_metadata_without_overrides( + self, test_db + ): + created_simulation = service.create_simulation( + country_id="us", + population_id="household_metadata_preserve", + population_type="household", + policy_id=16, + ) + + service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + status="complete", + output=json.dumps({"result": "ok"}), + version_manifest_overrides={ + "country_package_version": "1.620.0", + "policyengine_version": "0.94.2", + "data_version": "2026.04.16", + "runtime_app_name": "policyengine-app-v2", + }, + ) + + service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + status="error", + error_message="later failure", + ) + + run = test_db.query( + "SELECT * FROM simulation_runs WHERE simulation_id = ?", + (created_simulation["id"],), + ).fetchone() + assert run["status"] == "error" + assert run["error_message"] == "later failure" + assert run["country_package_version"] == "1.620.0" + assert run["policyengine_version"] == "0.94.2" + assert run["data_version"] == "2026.04.16" + assert run["runtime_app_name"] == "policyengine-app-v2" + + def test_update_simulation_allows_explicit_metadata_override_on_existing_run( + self, test_db + ): + created_simulation = service.create_simulation( + country_id="us", + population_id="household_metadata_override", + population_type="household", + policy_id=17, + ) + + service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + status="complete", + output=json.dumps({"result": "ok"}), + version_manifest_overrides={ + "country_package_version": "1.620.0", + "policyengine_version": "0.94.2", + }, + ) + + service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + version_manifest_overrides={ + "policyengine_version": "0.95.0", + }, + ) + + run = test_db.query( + "SELECT * FROM simulation_runs WHERE simulation_id = ?", + (created_simulation["id"],), + ).fetchone() + assert run["country_package_version"] == "1.620.0" + assert run["policyengine_version"] == "0.95.0" From de0cde17eaa59315da466e734e1c71b04239187e Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 22 Apr 2026 23:31:01 +0200 Subject: [PATCH 07/17] Validate report aliases by canonical identity --- .../services/report_output_alias_service.py | 53 ++++++++-- .../test_report_output_alias_service.py | 100 ++++++++++++++++-- .../services/test_report_output_service.py | 7 +- 3 files changed, 139 insertions(+), 21 deletions(-) diff --git a/policyengine_api/services/report_output_alias_service.py b/policyengine_api/services/report_output_alias_service.py index 9440cfdfd..6a120ba6c 100644 --- a/policyengine_api/services/report_output_alias_service.py +++ b/policyengine_api/services/report_output_alias_service.py @@ -7,7 +7,7 @@ class ReportOutputAliasService: def _get_report_output_row(self, report_output_id: int) -> dict | None: row: Row | None = database.query( """ - SELECT id, country_id, simulation_1_id, simulation_2_id, year + SELECT id, country_id, report_identity_hash, report_identity_schema_version FROM report_outputs WHERE id = ? """, @@ -15,6 +15,45 @@ def _get_report_output_row(self, report_output_id: int) -> dict | None: ).fetchone() return dict(row) if row is not None else None + def _validate_alias_identity_compatibility( + self, + legacy_report_output: dict, + canonical_report_output: dict, + ) -> None: + if legacy_report_output["country_id"] != canonical_report_output["country_id"]: + raise ValueError( + "Legacy and canonical report outputs must describe the same report" + ) + + if ( + legacy_report_output["report_identity_hash"] is None + or legacy_report_output["report_identity_schema_version"] is None + ): + raise ValueError( + "Legacy report output must have canonical report identity before " + "aliasing" + ) + + if ( + canonical_report_output["report_identity_hash"] is None + or canonical_report_output["report_identity_schema_version"] is None + ): + raise ValueError( + "Canonical report output must have canonical report identity before " + "aliasing" + ) + + if ( + legacy_report_output["report_identity_hash"] + != canonical_report_output["report_identity_hash"] + or legacy_report_output["report_identity_schema_version"] + != canonical_report_output["report_identity_schema_version"] + ): + raise ValueError( + "Legacy and canonical report outputs must share canonical report " + "identity" + ) + def get_alias(self, legacy_report_output_id: int) -> dict | None: row: Row | None = database.query( """ @@ -78,14 +117,10 @@ def set_alias( f"#{existing_alias['canonical_report_output_id']}" ) - logical_key = ("country_id", "simulation_1_id", "simulation_2_id", "year") - if any( - legacy_report_output[field] != canonical_report_output[field] - for field in logical_key - ): - raise ValueError( - "Legacy and canonical report outputs must describe the same report" - ) + self._validate_alias_identity_compatibility( + legacy_report_output, + canonical_report_output, + ) database.query( """ INSERT INTO legacy_report_output_aliases diff --git a/tests/unit/services/test_report_output_alias_service.py b/tests/unit/services/test_report_output_alias_service.py index e4e28c916..d1d65002a 100644 --- a/tests/unit/services/test_report_output_alias_service.py +++ b/tests/unit/services/test_report_output_alias_service.py @@ -18,12 +18,15 @@ def _insert_legacy_report_output( legacy_report_output_id: int, canonical_report: dict, api_version: str = "legacy-version", + report_identity_hash: str | None = None, + report_identity_schema_version: int | None = None, ) -> None: test_db.query( """ INSERT INTO report_outputs ( - id, country_id, simulation_1_id, simulation_2_id, api_version, status, year - ) VALUES (?, ?, ?, ?, ?, ?, ?) + id, country_id, simulation_1_id, simulation_2_id, api_version, status, year, + report_identity_hash, report_identity_schema_version + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( legacy_report_output_id, @@ -33,6 +36,10 @@ def _insert_legacy_report_output( api_version, canonical_report["status"], canonical_report["year"], + report_identity_hash or canonical_report.get("report_identity_hash"), + report_identity_schema_version + if report_identity_schema_version is not None + else canonical_report.get("report_identity_schema_version"), ), ) @@ -194,9 +201,73 @@ def test_rejects_alias_when_legacy_report_output_is_missing(self, test_db): assert "Legacy report output #10030 not found" in str(exc_info.value) - def test_rejects_alias_when_legacy_and_canonical_reports_do_not_match( + def test_rejects_alias_when_reports_do_not_share_canonical_identity( self, test_db ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/tx", + population_type="geography", + policy_id=34, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/tx", + population_type="geography", + policy_id=35, + ) + default_report_spec = report_output_service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/tx", + "baseline_policy_id": 34, + "reform_policy_id": 35, + "dataset": "default", + "target": "general", + "options": {}, + } + ) + cliff_report_spec = report_output_service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/tx", + "baseline_policy_id": 34, + "reform_policy_id": 35, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=default_report_spec, + report_spec_schema_version=1, + ) + distinct_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=cliff_report_spec, + report_spec_schema_version=1, + ) + + with pytest.raises(ValueError) as exc_info: + alias_service.set_alias( + legacy_report_output_id=distinct_report["id"], + canonical_report_output_id=canonical_report["id"], + ) + + assert "must share canonical report identity" in str(exc_info.value) + + def test_rejects_alias_when_legacy_report_output_has_no_identity(self, test_db): simulation = simulation_service.create_simulation( country_id="us", population_id="household_4b", @@ -209,20 +280,29 @@ def test_rejects_alias_when_legacy_and_canonical_reports_do_not_match( simulation_2_id=None, year="2025", ) - mismatched_report = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2026", + self._insert_legacy_report_output( + test_db, + legacy_report_output_id=10031, + canonical_report=canonical_report, + report_identity_hash=None, + report_identity_schema_version=None, + ) + test_db.query( + """ + UPDATE report_outputs + SET report_identity_hash = NULL, report_identity_schema_version = NULL + WHERE id = ? + """, + (10031,), ) with pytest.raises(ValueError) as exc_info: alias_service.set_alias( - legacy_report_output_id=mismatched_report["id"], + legacy_report_output_id=10031, canonical_report_output_id=canonical_report["id"], ) - assert "must describe the same report" in str(exc_info.value) + assert "must have canonical report identity" in str(exc_info.value) def test_rejects_alias_when_legacy_and_canonical_ids_match(self, test_db): simulation = simulation_service.create_simulation( diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index 5be4389d7..e9e27881d 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -1378,8 +1378,9 @@ def test_get_report_output_resolves_alias_to_canonical_parent_and_display_run( test_db.query( """ INSERT INTO report_outputs ( - id, country_id, simulation_1_id, simulation_2_id, status, output, api_version, year - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + id, country_id, simulation_1_id, simulation_2_id, status, output, api_version, year, + report_identity_hash, report_identity_schema_version + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( 999, @@ -1390,6 +1391,8 @@ def test_get_report_output_resolves_alias_to_canonical_parent_and_display_run( json.dumps({"legacy": True}), "r0legacy1", "2025", + canonical_report["report_identity_hash"], + canonical_report["report_identity_schema_version"], ), ) alias_service.set_alias( From 72681ad3057f90ad0a0d62a0e683cfc217b92c8f Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 22 Apr 2026 23:43:52 +0200 Subject: [PATCH 08/17] Add coverage for report identity and run metadata fixes --- changelog.d/3500.changed.md | 1 + tests/unit/test_stage5_routes.py | 48 ++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 changelog.d/3500.changed.md diff --git a/changelog.d/3500.changed.md b/changelog.d/3500.changed.md new file mode 100644 index 000000000..c0dc238f0 --- /dev/null +++ b/changelog.d/3500.changed.md @@ -0,0 +1 @@ +Preserve explicit report definitions and execution metadata across later syncs, key new report creation and alias validation by canonical report identity, and resolve report reads through canonical parents plus display-run selection. diff --git a/tests/unit/test_stage5_routes.py b/tests/unit/test_stage5_routes.py index cc2631e55..34fac32f3 100644 --- a/tests/unit/test_stage5_routes.py +++ b/tests/unit/test_stage5_routes.py @@ -542,6 +542,54 @@ def test_get_report_output_alias_resolves_to_canonical_display_run(test_db): assert payload["result"]["api_version"] == get_report_output_cache_version("us") +def test_get_report_output_reads_malformed_legacy_row_without_runs_or_identity( + test_db, +): + household_simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_legacy_malformed", + population_type="household", + policy_id=58, + ) + geography_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/co", + population_type="geography", + policy_id=59, + ) + test_db.query( + """ + INSERT INTO report_outputs ( + country_id, simulation_1_id, simulation_2_id, api_version, status, output, year + ) VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + "us", + household_simulation["id"], + geography_simulation["id"], + "r0legacy-malformed", + "error", + json.dumps({"result": "legacy-malformed"}), + "2025", + ), + ) + malformed_report = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + client = create_test_client() + response = client.get(f"/us/report/{malformed_report['id']}") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["id"] == malformed_report["id"] + assert payload["result"]["status"] == "error" + assert payload["result"]["output"] == json.dumps( + {"result": "legacy-malformed"} + ) + assert payload["result"]["api_version"] == "r0legacy-malformed" + + def test_patch_report_output_wrong_country_returns_not_found_and_does_not_mutate( test_db, ): From 11fcff3ddfeb4049857bf4c1a0b23f21694d4740 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 8 May 2026 02:26:31 +0200 Subject: [PATCH 09/17] Resolve stage 6 rebase validation issues --- tests/unit/services/test_report_output_service.py | 7 ++++++- uv.lock | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index e9e27881d..42e3c88e0 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -1051,7 +1051,12 @@ def test_get_report_output_does_not_rewrite_terminal_active_run_for_running_pare "SELECT * FROM report_output_runs WHERE id = ?", (successful_run_id,), ).fetchone() - assert result["status"] == "running" + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + assert result["status"] == "complete" + assert stored_report["status"] == "running" assert successful_run["status"] == "complete" assert successful_run["output"] == output_json assert successful_run["finished_at"] is not None diff --git a/uv.lock b/uv.lock index 480fe985f..e16464992 100644 --- a/uv.lock +++ b/uv.lock @@ -2622,7 +2622,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/a0/f3/eeea7dab690e46cd9 [[package]] name = "policyengine-api" -version = "3.40.11" +version = "3.40.12" source = { editable = "." } dependencies = [ { name = "anthropic" }, From f10ecb1b53724a15e2230827e081d0dc5e3faf13 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 15:17:40 +0200 Subject: [PATCH 10/17] Implement canonical report runs and legacy ID map --- changelog.d/3500.changed.md | 2 +- policyengine_api/data/initialise.sql | 13 +- policyengine_api/data/initialise_local.sql | 15 +- .../routes/report_output_routes.py | 58 +- policyengine_api/routes/simulation_routes.py | 5 + .../services/report_output_alias_service.py | 132 ---- .../services/report_output_id_map_service.py | 185 +++++ .../services/report_output_service.py | 724 ++++++++++++++++-- .../services/simulation_service.py | 244 +++++- tests/unit/data/test_run_schema.py | 6 +- .../routes/test_route_exception_handling.py | 4 +- ...y => test_report_output_id_map_service.py} | 74 +- .../services/test_report_output_service.py | 127 ++- .../unit/services/test_simulation_service.py | 45 +- tests/unit/test_stage5_routes.py | 322 +++++++- 15 files changed, 1628 insertions(+), 328 deletions(-) delete mode 100644 policyengine_api/services/report_output_alias_service.py create mode 100644 policyengine_api/services/report_output_id_map_service.py rename tests/unit/services/{test_report_output_alias_service.py => test_report_output_id_map_service.py} (83%) diff --git a/changelog.d/3500.changed.md b/changelog.d/3500.changed.md index c0dc238f0..e54fc3e08 100644 --- a/changelog.d/3500.changed.md +++ b/changelog.d/3500.changed.md @@ -1 +1 @@ -Preserve explicit report definitions and execution metadata across later syncs, key new report creation and alias validation by canonical report identity, and resolve report reads through canonical parents plus display-run selection. +Preserve explicit report definitions and execution metadata, key report creation by canonical report identity, resolve legacy report IDs through a permanent compatibility map, and add run-targeted report and simulation rerun updates. diff --git a/policyengine_api/data/initialise.sql b/policyengine_api/data/initialise.sql index 4aacc9112..917ca0cf0 100644 --- a/policyengine_api/data/initialise.sql +++ b/policyengine_api/data/initialise.sql @@ -141,6 +141,11 @@ CREATE TABLE IF NOT EXISTS report_outputs ( latest_successful_run_id CHAR(36) DEFAULT NULL ); +CREATE INDEX report_outputs_identity_idx + ON report_outputs ( + country_id, report_identity_hash, report_identity_schema_version + ); + CREATE TABLE IF NOT EXISTS report_output_runs ( id CHAR(36) PRIMARY KEY, report_output_id INT NOT NULL, @@ -189,7 +194,13 @@ CREATE TABLE IF NOT EXISTS simulation_runs ( UNIQUE KEY simulation_run_sequence_idx (simulation_id, run_sequence) ); -CREATE TABLE IF NOT EXISTS legacy_report_output_aliases ( +CREATE INDEX simulation_runs_report_output_run_idx + ON simulation_runs (report_output_run_id); + +CREATE TABLE IF NOT EXISTS legacy_report_output_id_map ( legacy_report_output_id INT PRIMARY KEY, canonical_report_output_id INT NOT NULL ); + +CREATE INDEX legacy_report_output_id_map_canonical_idx + ON legacy_report_output_id_map (canonical_report_output_id); diff --git a/policyengine_api/data/initialise_local.sql b/policyengine_api/data/initialise_local.sql index b8530be65..d92930257 100644 --- a/policyengine_api/data/initialise_local.sql +++ b/policyengine_api/data/initialise_local.sql @@ -8,7 +8,7 @@ DROP TABLE IF EXISTS user_policies; DROP TABLE IF EXISTS tracers; DROP TABLE IF EXISTS report_output_runs; DROP TABLE IF EXISTS simulation_runs; -DROP TABLE IF EXISTS legacy_report_output_aliases; +DROP TABLE IF EXISTS legacy_report_output_id_map; CREATE TABLE IF NOT EXISTS household ( id INTEGER PRIMARY KEY, @@ -153,6 +153,11 @@ CREATE TABLE IF NOT EXISTS report_outputs ( latest_successful_run_id CHAR(36) DEFAULT NULL ); +CREATE INDEX report_outputs_identity_idx + ON report_outputs ( + country_id, report_identity_hash, report_identity_schema_version + ); + CREATE TABLE IF NOT EXISTS report_output_runs ( id CHAR(36) PRIMARY KEY, report_output_id INT NOT NULL, @@ -201,7 +206,13 @@ CREATE TABLE IF NOT EXISTS simulation_runs ( UNIQUE (simulation_id, run_sequence) ); -CREATE TABLE IF NOT EXISTS legacy_report_output_aliases ( +CREATE INDEX simulation_runs_report_output_run_idx + ON simulation_runs (report_output_run_id); + +CREATE TABLE IF NOT EXISTS legacy_report_output_id_map ( legacy_report_output_id INT PRIMARY KEY, canonical_report_output_id INT NOT NULL ); + +CREATE INDEX legacy_report_output_id_map_canonical_idx + ON legacy_report_output_id_map (canonical_report_output_id); diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index 475b016c1..cfafca18d 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -195,6 +195,53 @@ def get_report_output(country_id: str, report_id: int) -> Response: ) +@report_output_bp.route("//report//rerun", methods=["POST"]) +@validate_country +def create_report_rerun(country_id: str, report_id: int) -> Response: + """ + Create a new pending run for an existing report. + + The requested report ID may be a legacy ID; the run is always created under + the resolved canonical report output. + """ + payload = request.json or {} + if not isinstance(payload, dict): + raise BadRequest("Payload must be an object") + + version_manifest_overrides = _parse_report_run_metadata(payload) + + try: + if not report_output_service.report_output_exists(country_id, report_id): + raise NotFound(f"Report #{report_id} not found.") + + rerun = report_output_service.create_report_rerun( + country_id=country_id, + report_output_id=report_id, + version_manifest_overrides=version_manifest_overrides, + ) + except HTTPException: + raise + except ValueError as e: + current_app.logger.warning( + "Bad request creating report rerun #%s for country %s: %s", + report_id, + country_id, + e, + ) + raise BadRequest(f"Failed to create report rerun: {e}") from e + + response_body = dict( + status="ok", + message="Report rerun created successfully", + result=rerun, + ) + return Response( + json.dumps(response_body), + status=201, + mimetype="application/json", + ) + + @report_output_bp.route("//report", methods=["PATCH"]) @validate_country def update_report_output(country_id: str) -> Response: @@ -206,6 +253,7 @@ def update_report_output(country_id: str) -> Response: Request body can contain: - id (int): The report output ID. + - report_output_run_id (str | None): Specific report run to update. - status (str): The new status ('pending', 'running', 'complete', or 'error') - output (dict): The result output (for complete status) - api_version (str): The API version of the report @@ -221,6 +269,7 @@ def update_report_output(country_id: str) -> Response: report_id = payload.get("id") output = payload.get("output") error_message = payload.get("error_message") + report_output_run_id = payload.get("report_output_run_id") version_manifest_overrides = _parse_report_run_metadata(payload) print(f"Updating report #{report_id} for country {country_id}") @@ -236,6 +285,8 @@ def update_report_output(country_id: str) -> Response: # Validate that complete status has output if status == "complete" and output is None: raise BadRequest("output is required when status is 'complete'") + if report_output_run_id is not None and not isinstance(report_output_run_id, str): + raise BadRequest("report_output_run_id must be a string") try: # First check if the report output exists without running pointer sync: @@ -251,17 +302,14 @@ def update_report_output(country_id: str) -> Response: status=status, output=output, error_message=error_message, + report_output_run_id=report_output_run_id, version_manifest_overrides=version_manifest_overrides, ) if not success: raise BadRequest("No fields to update") - # Get the updated stored record so stale-runtime jobs do not appear to - # complete the current runtime lineage in the PATCH response. - updated_report = report_output_service.get_stored_report_output( - country_id, report_id - ) + updated_report = report_output_service.get_report_output(country_id, report_id) response_body = dict( status="ok", diff --git a/policyengine_api/routes/simulation_routes.py b/policyengine_api/routes/simulation_routes.py index 2da34a962..19cafe473 100644 --- a/policyengine_api/routes/simulation_routes.py +++ b/policyengine_api/routes/simulation_routes.py @@ -178,6 +178,7 @@ def update_simulation(country_id: str) -> Response: Request body can contain: - id (int): The simulation ID. + - simulation_run_id (str | None): Specific simulation run to update. - status (str): The new status ('complete' or 'error') - output (dict): The result output (for complete status) - api_version (str): The API version of the simulation @@ -193,6 +194,7 @@ def update_simulation(country_id: str) -> Response: simulation_id = payload.get("id") output = payload.get("output") error_message = payload.get("error_message") + simulation_run_id = payload.get("simulation_run_id") version_manifest_overrides = _parse_simulation_run_metadata(payload) print(f"Updating simulation #{simulation_id} for country {country_id}") @@ -203,6 +205,8 @@ def update_simulation(country_id: str) -> Response: # Validate that complete status has output if status == "complete" and output is None: raise BadRequest("output is required when status is 'complete'") + if simulation_run_id is not None and not isinstance(simulation_run_id, str): + raise BadRequest("simulation_run_id must be a string") try: # First check if the simulation exists @@ -219,6 +223,7 @@ def update_simulation(country_id: str) -> Response: status=status, output=output, error_message=error_message, + simulation_run_id=simulation_run_id, version_manifest_overrides=version_manifest_overrides, ) diff --git a/policyengine_api/services/report_output_alias_service.py b/policyengine_api/services/report_output_alias_service.py deleted file mode 100644 index 6a120ba6c..000000000 --- a/policyengine_api/services/report_output_alias_service.py +++ /dev/null @@ -1,132 +0,0 @@ -from sqlalchemy.engine.row import Row - -from policyengine_api.data import database - - -class ReportOutputAliasService: - def _get_report_output_row(self, report_output_id: int) -> dict | None: - row: Row | None = database.query( - """ - SELECT id, country_id, report_identity_hash, report_identity_schema_version - FROM report_outputs - WHERE id = ? - """, - (report_output_id,), - ).fetchone() - return dict(row) if row is not None else None - - def _validate_alias_identity_compatibility( - self, - legacy_report_output: dict, - canonical_report_output: dict, - ) -> None: - if legacy_report_output["country_id"] != canonical_report_output["country_id"]: - raise ValueError( - "Legacy and canonical report outputs must describe the same report" - ) - - if ( - legacy_report_output["report_identity_hash"] is None - or legacy_report_output["report_identity_schema_version"] is None - ): - raise ValueError( - "Legacy report output must have canonical report identity before " - "aliasing" - ) - - if ( - canonical_report_output["report_identity_hash"] is None - or canonical_report_output["report_identity_schema_version"] is None - ): - raise ValueError( - "Canonical report output must have canonical report identity before " - "aliasing" - ) - - if ( - legacy_report_output["report_identity_hash"] - != canonical_report_output["report_identity_hash"] - or legacy_report_output["report_identity_schema_version"] - != canonical_report_output["report_identity_schema_version"] - ): - raise ValueError( - "Legacy and canonical report outputs must share canonical report " - "identity" - ) - - def get_alias(self, legacy_report_output_id: int) -> dict | None: - row: Row | None = database.query( - """ - SELECT * FROM legacy_report_output_aliases - WHERE legacy_report_output_id = ? - """, - (legacy_report_output_id,), - ).fetchone() - return dict(row) if row is not None else None - - def resolve_canonical_report_output_id( - self, requested_report_output_id: int - ) -> int | None: - alias = self.get_alias(requested_report_output_id) - if alias is not None: - canonical_report_output_id = alias["canonical_report_output_id"] - if self._get_report_output_row(canonical_report_output_id) is None: - raise ValueError( - "Alias points to missing canonical report output " - f"#{canonical_report_output_id}" - ) - return canonical_report_output_id - - row: Row | None = database.query( - "SELECT id FROM report_outputs WHERE id = ?", - (requested_report_output_id,), - ).fetchone() - return row["id"] if row is not None else None - - def set_alias( - self, - legacy_report_output_id: int, - canonical_report_output_id: int, - ) -> bool: - legacy_report_output = self._get_report_output_row(legacy_report_output_id) - if legacy_report_output is None: - raise ValueError( - f"Legacy report output #{legacy_report_output_id} not found" - ) - - canonical_report_output = self._get_report_output_row( - canonical_report_output_id - ) - if canonical_report_output is None: - raise ValueError( - f"Canonical report output #{canonical_report_output_id} not found" - ) - if legacy_report_output_id == canonical_report_output_id: - raise ValueError("Legacy and canonical report outputs must be different") - - existing_alias = self.get_alias(legacy_report_output_id) - if existing_alias is not None: - if ( - existing_alias["canonical_report_output_id"] - == canonical_report_output_id - ): - return True - - raise ValueError( - "Legacy report output alias already points to canonical report output " - f"#{existing_alias['canonical_report_output_id']}" - ) - - self._validate_alias_identity_compatibility( - legacy_report_output, - canonical_report_output, - ) - database.query( - """ - INSERT INTO legacy_report_output_aliases - (legacy_report_output_id, canonical_report_output_id) - VALUES (?, ?) - """, - (legacy_report_output_id, canonical_report_output_id), - ) - return True diff --git a/policyengine_api/services/report_output_id_map_service.py b/policyengine_api/services/report_output_id_map_service.py new file mode 100644 index 000000000..80b341f3a --- /dev/null +++ b/policyengine_api/services/report_output_id_map_service.py @@ -0,0 +1,185 @@ +from sqlalchemy.engine.row import Row + +from policyengine_api.data import database + + +class ReportOutputIdMapService: + def _get_report_output_row( + self, + report_output_id: int, + *, + queryer=None, + country_id: str | None = None, + ) -> dict | None: + queryer = queryer or database + query = """ + SELECT id, country_id, report_identity_hash, + report_identity_schema_version + FROM report_outputs + WHERE id = ? + """ + params: list[int | str] = [report_output_id] + if country_id is not None: + query += " AND country_id = ?" + params.append(country_id) + + row: Row | None = queryer.query(query, tuple(params)).fetchone() + return dict(row) if row is not None else None + + def _validate_mapping_identity_compatibility( + self, + legacy_report_output: dict, + canonical_report_output: dict, + ) -> None: + if legacy_report_output["country_id"] != canonical_report_output["country_id"]: + raise ValueError( + "Legacy and canonical report outputs must describe the same report" + ) + + if ( + legacy_report_output["report_identity_hash"] is None + or legacy_report_output["report_identity_schema_version"] is None + ): + raise ValueError( + "Legacy report output must have canonical report identity before " + "mapping" + ) + + if ( + canonical_report_output["report_identity_hash"] is None + or canonical_report_output["report_identity_schema_version"] is None + ): + raise ValueError( + "Canonical report output must have canonical report identity before " + "mapping" + ) + + if ( + legacy_report_output["report_identity_hash"] + != canonical_report_output["report_identity_hash"] + or legacy_report_output["report_identity_schema_version"] + != canonical_report_output["report_identity_schema_version"] + ): + raise ValueError( + "Legacy and canonical report outputs must share canonical report " + "identity" + ) + + def get_mapping( + self, + legacy_report_output_id: int, + *, + queryer=None, + ) -> dict | None: + queryer = queryer or database + row: Row | None = queryer.query( + """ + SELECT * FROM legacy_report_output_id_map + WHERE legacy_report_output_id = ? + """, + (legacy_report_output_id,), + ).fetchone() + return dict(row) if row is not None else None + + def resolve_report_output_id( + self, + requested_report_output_id: int, + *, + queryer=None, + country_id: str | None = None, + ) -> dict | None: + queryer = queryer or database + mapping = self.get_mapping(requested_report_output_id, queryer=queryer) + if mapping is not None: + canonical_report_output_id = mapping["canonical_report_output_id"] + canonical_report_output = self._get_report_output_row( + canonical_report_output_id, + queryer=queryer, + country_id=country_id, + ) + if canonical_report_output is None: + raise ValueError( + "Legacy ID mapping points to missing canonical report output " + f"#{canonical_report_output_id}" + ) + return { + "requested_report_output_id": requested_report_output_id, + "canonical_report_output_id": canonical_report_output_id, + "is_legacy_id": True, + } + + requested_report_output = self._get_report_output_row( + requested_report_output_id, + queryer=queryer, + country_id=country_id, + ) + if requested_report_output is None: + return None + + return { + "requested_report_output_id": requested_report_output_id, + "canonical_report_output_id": requested_report_output_id, + "is_legacy_id": False, + } + + def resolve_canonical_report_output_id( + self, + requested_report_output_id: int, + *, + queryer=None, + country_id: str | None = None, + ) -> int | None: + resolution = self.resolve_report_output_id( + requested_report_output_id, + queryer=queryer, + country_id=country_id, + ) + if resolution is None: + return None + return resolution["canonical_report_output_id"] + + def set_mapping( + self, + legacy_report_output_id: int, + canonical_report_output_id: int, + ) -> bool: + if legacy_report_output_id == canonical_report_output_id: + raise ValueError("Legacy and canonical report outputs must be different") + + canonical_report_output = self._get_report_output_row( + canonical_report_output_id + ) + if canonical_report_output is None: + raise ValueError( + f"Canonical report output #{canonical_report_output_id} not found" + ) + + existing_mapping = self.get_mapping(legacy_report_output_id) + if existing_mapping is not None: + if ( + existing_mapping["canonical_report_output_id"] + == canonical_report_output_id + ): + return True + + raise ValueError( + "Legacy report output ID already maps to canonical report output " + f"#{existing_mapping['canonical_report_output_id']}" + ) + + legacy_report_output = self._get_report_output_row(legacy_report_output_id) + if legacy_report_output is not None: + self._validate_mapping_identity_compatibility( + legacy_report_output, + canonical_report_output, + ) + + database.query( + """ + INSERT INTO legacy_report_output_id_map + (legacy_report_output_id, canonical_report_output_id) + VALUES (?, ?) + """, + (legacy_report_output_id, canonical_report_output_id), + ) + return True diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index 9a7f9770d..246ec764c 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -5,8 +5,8 @@ from policyengine_api.constants import get_report_output_cache_version from policyengine_api.data import database -from policyengine_api.services.report_output_alias_service import ( - ReportOutputAliasService, +from policyengine_api.services.report_output_id_map_service import ( + ReportOutputIdMapService, ) from policyengine_api.services.report_run_service import ReportRunService from policyengine_api.services.report_spec_service import ( @@ -29,7 +29,7 @@ class ReportOutputService: def __init__(self): self.report_spec_service = ReportSpecService() self.simulation_service = SimulationService() - self.report_output_alias_service = ReportOutputAliasService() + self.report_output_id_map_service = ReportOutputIdMapService() self.report_run_service = ReportRunService() def _lock_clause(self) -> str: @@ -97,6 +97,17 @@ def _get_report_output_row( row: Row | None = queryer.query(query, tuple(params)).fetchone() return dict(row) if row is not None else None + def _get_last_inserted_report_output_id(self, tx) -> int: + query = ( + "SELECT last_insert_rowid() AS id" + if database.local + else "SELECT LAST_INSERT_ID() AS id" + ) + row = tx.query(query).fetchone() + if row is None or row["id"] is None: + raise Exception("Failed to retrieve inserted report output ID") + return int(row["id"]) + def _get_linked_simulations( self, report_output: dict, @@ -185,6 +196,33 @@ def _list_report_runs_descending( runs.append(run) return runs + def _get_report_run_row( + self, + run_id: str, + *, + queryer=None, + report_output_id: int | None = None, + for_update: bool = False, + ) -> dict | None: + queryer = queryer or database + query = "SELECT * FROM report_output_runs WHERE id = ?" + params: list[str | int] = [run_id] + if report_output_id is not None: + query += " AND report_output_id = ?" + params.append(report_output_id) + if for_update: + query += self._lock_clause() + + row: Row | None = queryer.query(query, tuple(params)).fetchone() + if row is None: + return None + + run = dict(row) + run["report_spec_snapshot_json"] = parse_json_field( + run.get("report_spec_snapshot_json") + ) + return run + def _select_mutable_run( self, report_output: dict, runs_descending: list[dict] ) -> dict | None: @@ -644,6 +682,164 @@ def _insert_bootstrap_report_run( ), ) + def _insert_report_run_in_transaction( + self, + tx, + report_output: dict, + *, + status: str, + trigger_type: str, + source_run_id: str | None, + report_spec: ReportSpec | None, + version_manifest: dict[str, str | None], + ) -> str: + run_sequence_row: Row | None = tx.query( + """ + SELECT COALESCE(MAX(run_sequence), 0) AS max_run_sequence + FROM report_output_runs + WHERE report_output_id = ? + """, + (report_output["id"],), + ).fetchone() + run_sequence = ( + int(run_sequence_row["max_run_sequence"]) + 1 + if run_sequence_row is not None + else 1 + ) + run_id = str(uuid.uuid4()) + requested_at = self._utc_timestamp() + is_terminal = status in ("complete", "error") + has_started = status in ("running", "complete", "error") + started_at = requested_at if has_started else None + finished_at = requested_at if is_terminal else None + + tx.query( + """ + INSERT INTO report_output_runs ( + id, report_output_id, run_sequence, status, output, error_message, + trigger_type, requested_at, started_at, finished_at, source_run_id, + report_spec_snapshot_json, country_package_version, policyengine_version, + data_version, runtime_app_name, report_cache_version, + simulation_cache_version, requested_version_override, resolved_dataset, + resolved_options_hash + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + run_id, + report_output["id"], + run_sequence, + status, + None, + None, + trigger_type, + requested_at, + started_at, + finished_at, + source_run_id, + (report_spec.model_dump_json() if report_spec is not None else None), + version_manifest["country_package_version"], + version_manifest["policyengine_version"], + version_manifest["data_version"], + version_manifest["runtime_app_name"], + version_manifest["report_cache_version"], + version_manifest["simulation_cache_version"], + version_manifest["requested_version_override"], + version_manifest["resolved_dataset"], + version_manifest["resolved_options_hash"], + ), + ) + return run_id + + def _select_simulation_display_run( + self, + simulation: dict, + runs_descending: list[dict], + ) -> dict | None: + active_run_id = simulation.get("active_run_id") + if active_run_id is not None: + for run in runs_descending: + if run["id"] == active_run_id: + return run + + latest_successful_run_id = simulation.get("latest_successful_run_id") + if latest_successful_run_id is not None: + for run in runs_descending: + if run["id"] == latest_successful_run_id: + return run + + return runs_descending[0] if runs_descending else None + + def _insert_simulation_run_in_transaction( + self, + tx, + simulation: dict, + *, + report_output_run_id: str, + input_position: int, + source_run: dict | None, + ) -> str: + simulation_spec = ( + self.simulation_service._upsert_simulation_spec_in_transaction( + tx, + simulation, + ) + ) + version_manifest = ( + self.simulation_service._build_existing_run_version_manifest( + source_run, + simulation, + ) + if source_run is not None + else self.simulation_service._build_bootstrap_version_manifest(simulation) + ) + run_sequence_row: Row | None = tx.query( + """ + SELECT COALESCE(MAX(run_sequence), 0) AS max_run_sequence + FROM simulation_runs + WHERE simulation_id = ? + """, + (simulation["id"],), + ).fetchone() + run_sequence = ( + int(run_sequence_row["max_run_sequence"]) + 1 + if run_sequence_row is not None + else 1 + ) + run_id = str(uuid.uuid4()) + tx.query( + """ + INSERT INTO simulation_runs ( + id, simulation_id, report_output_run_id, input_position, run_sequence, + status, output, error_message, trigger_type, requested_at, started_at, + finished_at, source_run_id, simulation_spec_snapshot_json, + country_package_version, policyengine_version, data_version, + runtime_app_name, simulation_cache_version + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + run_id, + simulation["id"], + report_output_run_id, + input_position, + run_sequence, + "pending", + None, + None, + "report_rerun", + self._utc_timestamp(), + None, + None, + source_run["id"] if source_run is not None else None, + simulation_spec.model_dump_json(), + version_manifest["country_package_version"], + version_manifest["policyengine_version"], + version_manifest["data_version"], + version_manifest["runtime_app_name"], + version_manifest["simulation_cache_version"], + ), + ) + return run_id + def _update_report_run_in_transaction( self, tx, @@ -743,6 +939,57 @@ def _sync_parent_pointers_in_transaction( report_output["active_run_id"] = desired_active_run_id report_output["latest_successful_run_id"] = desired_latest_successful_run_id + def _sync_parent_mirror_from_display_run_in_transaction( + self, + tx, + report_output: dict, + runs_descending: list[dict], + ) -> dict: + self._sync_parent_pointers_in_transaction(tx, report_output, runs_descending) + display_run = select_display_report_run(report_output, runs_descending) + if display_run is None: + refreshed_report_output = self._get_report_output_row( + report_output["id"], + queryer=tx, + country_id=report_output["country_id"], + ) + if refreshed_report_output is None: + raise ValueError( + f"Report output #{report_output['id']} not found after sync" + ) + return refreshed_report_output + + parent_api_version = ( + display_run["report_cache_version"] + if display_run.get("report_cache_version") is not None + else report_output.get("api_version") + ) + tx.query( + """ + UPDATE report_outputs + SET status = ?, output = ?, error_message = ?, api_version = ? + WHERE id = ? AND country_id = ? + """, + ( + display_run["status"], + serialize_json_field(display_run.get("output")), + display_run.get("error_message"), + parent_api_version, + report_output["id"], + report_output["country_id"], + ), + ) + refreshed_report_output = self._get_report_output_row( + report_output["id"], + queryer=tx, + country_id=report_output["country_id"], + ) + if refreshed_report_output is None: + raise ValueError( + f"Report output #{report_output['id']} not found after sync" + ) + return refreshed_report_output + def _ensure_report_output_dual_write_state_in_transaction( self, tx, @@ -895,7 +1142,7 @@ def get_stored_report_output( self, country_id: str, report_output_id: int ) -> dict | None: """ - Get a stored report output row without aliasing to current runtime lineage. + Get a stored report output row without resolving legacy ID mappings. This is used by mutation paths that must address the originally requested row. It still runs dual-write synchronization, so it may @@ -917,7 +1164,10 @@ def get_stored_report_output( def report_output_exists(self, country_id: str, report_output_id: int) -> bool: return ( - self._get_report_output_row(report_output_id, country_id=country_id) + self.report_output_id_map_service.resolve_report_output_id( + report_output_id, + country_id=country_id, + ) is not None ) @@ -961,7 +1211,7 @@ def _find_existing_report_output_row_by_identity( SELECT * FROM report_outputs WHERE country_id = ? AND report_identity_hash = ? AND report_identity_schema_version = ? - ORDER BY id DESC + ORDER BY id ASC """, ( country_id, @@ -991,7 +1241,7 @@ def _list_report_output_rows_by_legacy_key( params.append(simulation_2_id) else: query += " AND simulation_2_id IS NULL" - query += " ORDER BY id DESC" + query += " ORDER BY id ASC" rows = queryer.query(query, tuple(params)).fetchall() return [dict(row) for row in rows] @@ -1004,39 +1254,67 @@ def _build_report_spec_for_create( simulation_2_id: int | None, year: str, queryer=None, - ) -> ReportSpec | None: + ) -> ReportSpec: queryer = queryer or database - simulation_1 = self.simulation_service._get_simulation_row( - simulation_1_id, - queryer=queryer, + simulation_1 = self._require_simulation_exists( + queryer, country_id=country_id, + simulation_id=simulation_1_id, ) - if simulation_1 is None: - return None simulation_2 = None if simulation_2_id is not None: - simulation_2 = self.simulation_service._get_simulation_row( - simulation_2_id, - queryer=queryer, + simulation_2 = self._require_simulation_exists( + queryer, country_id=country_id, + simulation_id=simulation_2_id, ) - if simulation_2 is None: - return None - try: - return self.report_spec_service.build_report_spec( - report_output={ - "country_id": country_id, - "simulation_1_id": simulation_1_id, - "simulation_2_id": simulation_2_id, - "year": year, - }, - simulation_1=simulation_1, - simulation_2=simulation_2, + return self.report_spec_service.build_report_spec( + report_output={ + "country_id": country_id, + "simulation_1_id": simulation_1_id, + "simulation_2_id": simulation_2_id, + "year": year, + }, + simulation_1=simulation_1, + simulation_2=simulation_2, + ) + + def _validate_explicit_report_spec_for_create( + self, + *, + country_id: str, + simulation_1_id: int, + simulation_2_id: int | None, + year: str, + report_spec: ReportSpec, + queryer, + ) -> None: + simulation_1 = self._require_simulation_exists( + queryer, + country_id=country_id, + simulation_id=simulation_1_id, + ) + simulation_2 = None + if simulation_2_id is not None: + simulation_2 = self._require_simulation_exists( + queryer, + country_id=country_id, + simulation_id=simulation_2_id, ) - except ValueError: - return None + + self.report_spec_service.validate_report_spec_matches_context( + { + "country_id": country_id, + "simulation_1_id": simulation_1_id, + "simulation_2_id": simulation_2_id, + "year": year, + }, + report_spec, + simulation_1, + simulation_2, + ) def _get_report_spec_for_identity_matching( self, @@ -1093,22 +1371,23 @@ def _find_existing_report_output_for_create( queryer=None, ) -> dict | None: queryer = queryer or database - identity_report_spec = report_spec or self._build_report_spec_for_create( - country_id=country_id, - simulation_1_id=simulation_1_id, - simulation_2_id=simulation_2_id, - year=year, - queryer=queryer, - ) - if identity_report_spec is None: - return self._find_existing_report_output_row( + if report_spec is not None: + self._validate_explicit_report_spec_for_create( country_id=country_id, simulation_1_id=simulation_1_id, simulation_2_id=simulation_2_id, year=year, + report_spec=report_spec, queryer=queryer, ) + identity_report_spec = report_spec or self._build_report_spec_for_create( + country_id=country_id, + simulation_1_id=simulation_1_id, + simulation_2_id=simulation_2_id, + year=year, + queryer=queryer, + ) report_identity_hash, report_identity_schema_version = ( self.report_spec_service.get_report_identity(identity_report_spec) ) @@ -1146,10 +1425,12 @@ def _find_existing_report_output_for_create( return None - def _alias_report_output(self, report_output_id: int, report_output: dict) -> dict: - aliased_report = dict(report_output) - aliased_report["id"] = report_output_id - return aliased_report + def _with_requested_report_output_id( + self, report_output_id: int, report_output: dict + ) -> dict: + response_report = dict(report_output) + response_report["id"] = report_output_id + return response_report def _merge_display_run_into_report_output( self, @@ -1311,12 +1592,11 @@ def tx_callback(tx): ), ) - created_report = self._find_existing_report_output_row( - country_id=country_id, - simulation_1_id=simulation_1_id, - simulation_2_id=simulation_2_id, - year=year, + created_report_id = self._get_last_inserted_report_output_id(tx) + created_report = self._get_report_output_row( + created_report_id, queryer=tx, + country_id=country_id, ) if created_report is None: raise Exception("Failed to retrieve created report output") @@ -1348,14 +1628,14 @@ def get_report_output(self, country_id: str, report_output_id: int) -> dict | No f"Invalid report output ID: {report_output_id}. Must be a positive integer." ) - canonical_report_output_id = ( - self.report_output_alias_service.resolve_canonical_report_output_id( - report_output_id - ) + resolution = self.report_output_id_map_service.resolve_report_output_id( + report_output_id, + country_id=country_id, ) - if canonical_report_output_id is None: + if resolution is None: return None + canonical_report_output_id = resolution["canonical_report_output_id"] canonical_report_output = self._get_report_output_row( canonical_report_output_id, country_id=country_id, @@ -1384,8 +1664,8 @@ def get_report_output(self, country_id: str, report_output_id: int) -> dict | No canonical_report_output, display_run, ) - if report_output_id != canonical_report_output_id: - return self._alias_report_output( + if resolution["is_legacy_id"]: + return self._with_requested_report_output_id( report_output_id, resolved_report_output, ) @@ -1397,6 +1677,182 @@ def get_report_output(self, country_id: str, report_output_id: int) -> dict | No ) raise e + def create_report_rerun( + self, + country_id: str, + report_output_id: int, + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict: + """ + Create a new pending run for the canonical report resolved from the + requested report ID. + """ + print(f"Creating report rerun for report output {report_output_id}") + + def tx_callback(tx): + resolution = self.report_output_id_map_service.resolve_report_output_id( + report_output_id, + queryer=tx, + country_id=country_id, + ) + if resolution is None: + raise ValueError(f"Report output #{report_output_id} not found") + + canonical_report_id = resolution["canonical_report_output_id"] + canonical_report = ( + self._ensure_report_output_dual_write_state_in_transaction( + tx, + canonical_report_id, + country_id=country_id, + ) + ) + canonical_report = self._get_report_output_row( + canonical_report_id, + queryer=tx, + country_id=country_id, + for_update=True, + ) + if canonical_report is None: + raise ValueError(f"Report output #{report_output_id} not found") + + simulation_1, simulation_2 = self._get_linked_simulations( + canonical_report, + queryer=tx, + bootstrap_dual_write_state=True, + ) + report_spec = self._upsert_report_spec_in_transaction( + tx, + canonical_report, + simulation_1, + simulation_2, + ) + existing_runs_descending = self._list_report_runs_descending( + canonical_report_id, + queryer=tx, + ) + source_report_run = select_display_report_run( + canonical_report, + existing_runs_descending, + ) + report_version_manifest = ( + self._build_existing_run_version_manifest( + source_report_run, + canonical_report, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) + if source_report_run is not None + else self._build_bootstrap_version_manifest( + canonical_report, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) + ) + report_run_id = self._insert_report_run_in_transaction( + tx, + canonical_report, + status="pending", + trigger_type="rerun", + source_run_id=( + source_report_run["id"] if source_report_run is not None else None + ), + report_spec=report_spec, + version_manifest=report_version_manifest, + ) + + simulation_run_ids: list[str] = [] + for input_position, simulation in ( + (1, simulation_1), + (2, simulation_2), + ): + if simulation is None: + continue + + simulation_runs_descending = ( + self.simulation_service._list_simulation_runs_descending( + simulation["id"], + queryer=tx, + ) + ) + source_simulation_run = self._select_simulation_display_run( + simulation, + simulation_runs_descending, + ) + simulation_run_id = self._insert_simulation_run_in_transaction( + tx, + simulation, + report_output_run_id=report_run_id, + input_position=input_position, + source_run=source_simulation_run, + ) + simulation_run_ids.append(simulation_run_id) + + simulation["status"] = "pending" + simulation["output"] = None + simulation["error_message"] = None + simulation_runs_descending = ( + self.simulation_service._list_simulation_runs_descending( + simulation["id"], + queryer=tx, + ) + ) + self.simulation_service._sync_parent_pointers_in_transaction( + tx, + simulation, + simulation_runs_descending, + ) + tx.query( + """ + UPDATE simulations + SET status = ?, output = ?, error_message = ? + WHERE id = ? AND country_id = ? + """, + ( + "pending", + None, + None, + simulation["id"], + country_id, + ), + ) + + canonical_report["status"] = "pending" + canonical_report["output"] = None + canonical_report["error_message"] = None + report_runs_descending = self._list_report_runs_descending( + canonical_report_id, + queryer=tx, + ) + refreshed_report = self._sync_parent_mirror_from_display_run_in_transaction( + tx, + canonical_report, + report_runs_descending, + ) + selected_report = self._merge_display_run_into_report_output( + refreshed_report, + self._get_report_run_row(report_run_id, queryer=tx), + ) + return { + "requested_report_output_id": report_output_id, + "report_output_id": canonical_report_id, + "report_output_run_id": report_run_id, + "simulation_run_ids": simulation_run_ids, + "report_spec": ( + report_spec.model_dump() if report_spec is not None else None + ), + "report": selected_report, + } + + try: + return database.transaction(tx_callback) + except Exception as e: + print(f"Error creating report rerun #{report_output_id}. Details: {str(e)}") + raise e + def update_report_output( self, country_id: str, @@ -1404,6 +1860,7 @@ def update_report_output( status: str | None = None, output: str | None = None, error_message: str | None = None, + report_output_run_id: str | None = None, version_manifest_overrides: dict[str, str | None] | None = None, ) -> bool: """ @@ -1412,54 +1869,153 @@ def update_report_output( print(f"Updating report output {report_id}") try: - update_fields = [] - update_values = [] - - if status is not None: - update_fields.append("status = ?") - update_values.append(status) - - if output is not None: - update_fields.append("output = ?") - update_values.append(output) - - if error_message is not None: - update_fields.append("error_message = ?") - update_values.append(error_message) - - if not update_fields and not version_manifest_overrides: + has_user_fields = ( + status is not None or output is not None or error_message is not None + ) + if not has_user_fields and not version_manifest_overrides: print("No fields to update") return False def tx_callback(tx): - requested_report = self._get_report_output_row( + resolution = self.report_output_id_map_service.resolve_report_output_id( report_id, queryer=tx, country_id=country_id, + ) + if resolution is None: + raise ValueError(f"Report output #{report_id} not found") + + canonical_report_id = resolution["canonical_report_output_id"] + canonical_report = self._get_report_output_row( + canonical_report_id, + queryer=tx, + country_id=country_id, for_update=True, ) - if requested_report is None: + if canonical_report is None: raise ValueError(f"Report output #{report_id} not found") - if status == "running" and not self._has_mutable_running_run( - requested_report, queryer=tx + try: + simulation_1, simulation_2 = self._get_linked_simulations( + canonical_report, + queryer=tx, + bootstrap_dual_write_state=True, + ) + except ValueError as exc: + print( + "Skipping linked simulation sync for report output " + f"#{canonical_report_id}. Details: {str(exc)}" + ) + simulation_1, simulation_2 = None, None + + report_spec = self._upsert_report_spec_in_transaction( + tx, + canonical_report, + simulation_1, + simulation_2, + ) + self._sync_report_identity_in_transaction( + tx, + canonical_report, + report_spec, + ) + + runs_descending = self._list_report_runs_descending( + canonical_report_id, + queryer=tx, + ) + if not runs_descending: + version_manifest = self._build_bootstrap_version_manifest( + canonical_report, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) + self._insert_bootstrap_report_run( + tx, + canonical_report, + report_spec, + version_manifest, + ) + runs_descending = self._list_report_runs_descending( + canonical_report_id, + queryer=tx, + ) + + if report_output_run_id is not None: + mutable_run = self._get_report_run_row( + report_output_run_id, + queryer=tx, + report_output_id=canonical_report_id, + for_update=True, + ) + if mutable_run is None: + raise ValueError( + "Report output run " + f"#{report_output_run_id} not found for report " + f"#{canonical_report_id}" + ) + else: + mutable_run = self._select_mutable_run( + canonical_report, + runs_descending, + ) + + if mutable_run is None: + raise ValueError( + "Cannot update report output without an active report run" + ) + + if status == "running" and mutable_run["status"] not in ( + "pending", + "running", ): raise ValueError( "Cannot mark report output running without an active " "pending or running report run" ) - if update_fields: - tx.query( - f"UPDATE report_outputs SET {', '.join(update_fields)} WHERE id = ? AND country_id = ?", - (*update_values, report_id, country_id), - ) - self._ensure_report_output_dual_write_state_in_transaction( - tx, - report_id, - country_id=country_id, + run_update_state = dict(canonical_report) + run_update_state["status"] = ( + status if status is not None else mutable_run["status"] + ) + run_update_state["output"] = ( + output if output is not None else mutable_run.get("output") + ) + run_update_state["error_message"] = ( + error_message + if error_message is not None + else mutable_run.get("error_message") + ) + + version_manifest = self._build_existing_run_version_manifest( + mutable_run, + canonical_report, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, version_manifest_overrides=version_manifest_overrides, ) + self._update_report_run_in_transaction( + tx, + run_id=mutable_run["id"], + report_output=run_update_state, + report_spec=report_spec, + version_manifest=version_manifest, + ) + canonical_report["status"] = run_update_state["status"] + canonical_report["output"] = run_update_state["output"] + canonical_report["error_message"] = run_update_state["error_message"] + runs_descending = self._list_report_runs_descending( + canonical_report_id, + queryer=tx, + ) + self._sync_parent_mirror_from_display_run_in_transaction( + tx, + canonical_report, + runs_descending, + ) database.transaction(tx_callback) diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index 157c5e1bd..35ef7a110 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -93,14 +93,15 @@ def _build_bootstrap_version_manifest( def _build_existing_run_version_manifest( self, run: dict, + simulation: dict, version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict[str, str | None]: + fallback_manifest = self._build_bootstrap_version_manifest(simulation) version_manifest = { - "country_package_version": run.get("country_package_version"), - "policyengine_version": run.get("policyengine_version"), - "data_version": run.get("data_version"), - "runtime_app_name": run.get("runtime_app_name"), - "simulation_cache_version": run.get("simulation_cache_version"), + key: run.get(key) + if run.get(key) is not None + else fallback_manifest.get(key) + for key in fallback_manifest } return self._merge_version_manifest_overrides( version_manifest, @@ -129,6 +130,32 @@ def _list_simulation_runs_descending( runs.append(run) return runs + def _get_simulation_run_row( + self, + run_id: str, + *, + queryer=None, + simulation_id: int | None = None, + for_update: bool = False, + ) -> dict | None: + queryer = queryer or database + query = "SELECT * FROM simulation_runs WHERE id = ?" + params: list[str | int] = [run_id] + if simulation_id is not None: + query += " AND simulation_id = ?" + params.append(simulation_id) + if for_update: + query += self._lock_clause() + + row: Row | None = queryer.query(query, tuple(params)).fetchone() + if row is None: + return None + run = dict(row) + run["simulation_spec_snapshot_json"] = parse_json_field( + run.get("simulation_spec_snapshot_json") + ) + return run + def _select_mutable_run( self, simulation: dict, runs_descending: list[dict] ) -> dict | None: @@ -139,6 +166,23 @@ def _select_mutable_run( return run return runs_descending[0] if runs_descending else None + def _select_display_run( + self, simulation: dict, runs_descending: list[dict] + ) -> dict | None: + active_run_id = simulation.get("active_run_id") + if active_run_id is not None: + for run in runs_descending: + if run["id"] == active_run_id: + return run + + latest_successful_run_id = simulation.get("latest_successful_run_id") + if latest_successful_run_id is not None: + for run in runs_descending: + if run["id"] == latest_successful_run_id: + return run + + return runs_descending[0] if runs_descending else None + def _upsert_simulation_spec_in_transaction( self, tx, simulation: dict ) -> SimulationSpec: @@ -286,6 +330,69 @@ def _sync_parent_pointers_in_transaction( simulation["active_run_id"] = desired_active_run_id simulation["latest_successful_run_id"] = desired_latest_successful_run_id + def _sync_parent_mirror_from_display_run_in_transaction( + self, + tx, + simulation: dict, + runs_descending: list[dict], + ) -> dict: + self._sync_parent_pointers_in_transaction(tx, simulation, runs_descending) + display_run = self._select_display_run(simulation, runs_descending) + if display_run is None: + refreshed_simulation = self._get_simulation_row( + simulation["id"], + queryer=tx, + country_id=simulation["country_id"], + ) + if refreshed_simulation is None: + raise ValueError(f"Simulation #{simulation['id']} not found after sync") + return refreshed_simulation + + parent_api_version = ( + display_run["simulation_cache_version"] + if display_run.get("simulation_cache_version") is not None + else simulation.get("api_version") + ) + tx.query( + """ + UPDATE simulations + SET status = ?, output = ?, error_message = ?, api_version = ? + WHERE id = ? AND country_id = ? + """, + ( + display_run["status"], + serialize_json_field(display_run.get("output")), + display_run.get("error_message"), + parent_api_version, + simulation["id"], + simulation["country_id"], + ), + ) + refreshed_simulation = self._get_simulation_row( + simulation["id"], + queryer=tx, + country_id=simulation["country_id"], + ) + if refreshed_simulation is None: + raise ValueError(f"Simulation #{simulation['id']} not found after sync") + return refreshed_simulation + + def _merge_display_run_into_simulation( + self, + simulation: dict, + display_run: dict | None, + ) -> dict: + if display_run is None: + return dict(simulation) + + result = dict(simulation) + result["status"] = display_run["status"] + result["output"] = display_run.get("output") + result["error_message"] = display_run.get("error_message") + if display_run.get("simulation_cache_version") is not None: + result["api_version"] = display_run["simulation_cache_version"] + return result + def _ensure_simulation_dual_write_state_in_transaction( self, tx, @@ -326,6 +433,7 @@ def _ensure_simulation_dual_write_state_in_transaction( version_manifest = ( self._build_existing_run_version_manifest( mutable_run, + simulation, version_manifest_overrides=version_manifest_overrides, ) if mutable_run is not None @@ -512,7 +620,13 @@ def get_simulation(self, country_id: str, simulation_id: int) -> dict | None: f"Invalid simulation ID: {simulation_id}. Must be a positive integer." ) - return self._get_simulation_row(simulation_id, country_id=country_id) + simulation = self._get_simulation_row(simulation_id, country_id=country_id) + if simulation is None: + return None + + runs_descending = self._list_simulation_runs_descending(simulation_id) + display_run = self._select_display_run(simulation, runs_descending) + return self._merge_display_run_into_simulation(simulation, display_run) except Exception as e: print(f"Error fetching simulation #{simulation_id}. Details: {str(e)}") @@ -525,6 +639,7 @@ def update_simulation( status: str | None = None, output: str | None = None, error_message: str | None = None, + simulation_run_id: str | None = None, version_manifest_overrides: dict[str, str | None] | None = None, ) -> bool: """ @@ -541,40 +656,15 @@ def update_simulation( bool: True if update was successful. """ print(f"Updating simulation {simulation_id}") - api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id) try: - update_fields = [] - update_values = [] - - if status is not None: - update_fields.append("status = ?") - update_values.append(status) - - if output is not None: - update_fields.append("output = ?") - update_values.append(output) - - if error_message is not None: - update_fields.append("error_message = ?") - update_values.append(error_message) - - # Only refresh api_version when the caller is actually - # changing one of the user-supplied fields above. The - # previous code appended api_version unconditionally, so - # the "no fields to update" guard below never fired and a - # PATCH with an empty body still touched the row. - if not update_fields and not version_manifest_overrides: + has_user_fields = ( + status is not None or output is not None or error_message is not None + ) + if not has_user_fields and not version_manifest_overrides: print("No fields to update") return False - # Metadata-only PATCHes update the run manifest, not the - # parent simulation row; only append api_version when - # caller-supplied parent fields are changing. - if update_fields: - update_fields.append("api_version = ?") - update_values.append(api_version) - def tx_callback(tx): simulation = self._get_simulation_row( simulation_id, @@ -585,17 +675,87 @@ def tx_callback(tx): if simulation is None: raise ValueError(f"Simulation #{simulation_id} not found") - if update_fields: - tx.query( - f"UPDATE simulations SET {', '.join(update_fields)} WHERE id = ? AND country_id = ?", - (*update_values, simulation_id, country_id), - ) - self._ensure_simulation_dual_write_state_in_transaction( + simulation_spec = self._upsert_simulation_spec_in_transaction( tx, + simulation, + ) + runs_descending = self._list_simulation_runs_descending( simulation_id, - country_id=country_id, + queryer=tx, + ) + if not runs_descending: + version_manifest = self._build_bootstrap_version_manifest( + simulation, + version_manifest_overrides=version_manifest_overrides, + ) + self._insert_bootstrap_run( + tx, + simulation, + simulation_spec, + version_manifest=version_manifest, + ) + runs_descending = self._list_simulation_runs_descending( + simulation_id, + queryer=tx, + ) + + if simulation_run_id is not None: + mutable_run = self._get_simulation_run_row( + simulation_run_id, + queryer=tx, + simulation_id=simulation_id, + for_update=True, + ) + if mutable_run is None: + raise ValueError( + f"Simulation run #{simulation_run_id} not found for " + f"simulation #{simulation_id}" + ) + else: + mutable_run = self._select_mutable_run(simulation, runs_descending) + + if mutable_run is None: + raise ValueError( + "Cannot update simulation without an active simulation run" + ) + + run_update_state = dict(simulation) + run_update_state["status"] = ( + status if status is not None else mutable_run["status"] + ) + run_update_state["output"] = ( + output if output is not None else mutable_run.get("output") + ) + run_update_state["error_message"] = ( + error_message + if error_message is not None + else mutable_run.get("error_message") + ) + + version_manifest = self._build_existing_run_version_manifest( + mutable_run, + simulation, version_manifest_overrides=version_manifest_overrides, ) + self._update_simulation_run_in_transaction( + tx, + run_id=mutable_run["id"], + simulation=run_update_state, + simulation_spec=simulation_spec, + version_manifest=version_manifest, + ) + simulation["status"] = run_update_state["status"] + simulation["output"] = run_update_state["output"] + simulation["error_message"] = run_update_state["error_message"] + runs_descending = self._list_simulation_runs_descending( + simulation_id, + queryer=tx, + ) + self._sync_parent_mirror_from_display_run_in_transaction( + tx, + simulation, + runs_descending, + ) database.transaction(tx_callback) diff --git a/tests/unit/data/test_run_schema.py b/tests/unit/data/test_run_schema.py index 72b0ba6c0..5635220d3 100644 --- a/tests/unit/data/test_run_schema.py +++ b/tests/unit/data/test_run_schema.py @@ -63,8 +63,8 @@ def test_stage_one_run_schema_is_initialized_in_local_test_db(test_db): "simulation_cache_version", }.issubset(simulation_run_columns) - alias_columns = _column_names(test_db, "legacy_report_output_aliases") - assert {"legacy_report_output_id", "canonical_report_output_id"} == alias_columns + id_map_columns = _column_names(test_db, "legacy_report_output_id_map") + assert {"legacy_report_output_id", "canonical_report_output_id"} == id_map_columns def test_stage_one_schema_is_defined_in_both_sql_initializers(): @@ -76,7 +76,7 @@ def test_stage_one_schema_is_defined_in_both_sql_initializers(): required_snippets = [ "CREATE TABLE IF NOT EXISTS report_output_runs", "CREATE TABLE IF NOT EXISTS simulation_runs", - "CREATE TABLE IF NOT EXISTS legacy_report_output_aliases", + "CREATE TABLE IF NOT EXISTS legacy_report_output_id_map", "report_spec_json", "report_spec_status", "report_identity_hash", diff --git a/tests/unit/routes/test_route_exception_handling.py b/tests/unit/routes/test_route_exception_handling.py index b6fb38a28..124c8731b 100644 --- a/tests/unit/routes/test_route_exception_handling.py +++ b/tests/unit/routes/test_route_exception_handling.py @@ -63,7 +63,7 @@ def test_simulation_create_value_error_still_400(): def test_report_create_runtime_error_becomes_500(): client = _client_with(report_output_bp) with patch( - "policyengine_api.routes.report_output_routes.report_output_service.find_existing_report_output", + "policyengine_api.routes.report_output_routes.report_output_service.find_existing_report_output_for_create", side_effect=RuntimeError("db went away"), ): response = client.post( @@ -76,7 +76,7 @@ def test_report_create_runtime_error_becomes_500(): def test_report_create_value_error_still_400(): client = _client_with(report_output_bp) with patch( - "policyengine_api.routes.report_output_routes.report_output_service.find_existing_report_output", + "policyengine_api.routes.report_output_routes.report_output_service.find_existing_report_output_for_create", side_effect=ValueError("bad input"), ): response = client.post( diff --git a/tests/unit/services/test_report_output_alias_service.py b/tests/unit/services/test_report_output_id_map_service.py similarity index 83% rename from tests/unit/services/test_report_output_alias_service.py rename to tests/unit/services/test_report_output_id_map_service.py index d1d65002a..ce2a26326 100644 --- a/tests/unit/services/test_report_output_alias_service.py +++ b/tests/unit/services/test_report_output_id_map_service.py @@ -1,17 +1,17 @@ import pytest -from policyengine_api.services.report_output_alias_service import ( - ReportOutputAliasService, +from policyengine_api.services.report_output_id_map_service import ( + ReportOutputIdMapService, ) from policyengine_api.services.report_output_service import ReportOutputService from policyengine_api.services.simulation_service import SimulationService -alias_service = ReportOutputAliasService() +id_map_service = ReportOutputIdMapService() report_output_service = ReportOutputService() simulation_service = SimulationService() -class TestReportOutputAliasService: +class TestReportOutputIdMapService: def _insert_legacy_report_output( self, test_db, @@ -43,7 +43,7 @@ def _insert_legacy_report_output( ), ) - def test_resolves_to_canonical_report_output_id_when_alias_exists(self, test_db): + def test_resolves_to_canonical_report_output_id_when_mapping_exists(self, test_db): simulation = simulation_service.create_simulation( country_id="us", population_id="household_1", @@ -58,16 +58,16 @@ def test_resolves_to_canonical_report_output_id_when_alias_exists(self, test_db) ) self._insert_legacy_report_output(test_db, 999, canonical_report) - alias_service.set_alias( + id_map_service.set_mapping( legacy_report_output_id=999, canonical_report_output_id=canonical_report["id"], ) - resolved_id = alias_service.resolve_canonical_report_output_id(999) + resolved_id = id_map_service.resolve_canonical_report_output_id(999) assert resolved_id == canonical_report["id"] - def test_returns_requested_id_when_alias_is_not_needed(self, test_db): + def test_returns_requested_id_when_mapping_is_not_needed(self, test_db): simulation = simulation_service.create_simulation( country_id="us", population_id="household_2", @@ -81,16 +81,16 @@ def test_returns_requested_id_when_alias_is_not_needed(self, test_db): year="2025", ) - resolved_id = alias_service.resolve_canonical_report_output_id( + resolved_id = id_map_service.resolve_canonical_report_output_id( report_output["id"] ) assert resolved_id == report_output["id"] def test_returns_none_for_unknown_report_output(self, test_db): - assert alias_service.resolve_canonical_report_output_id(123456) is None + assert id_map_service.resolve_canonical_report_output_id(123456) is None - def test_set_alias_is_idempotent_for_same_canonical_report_output(self, test_db): + def test_set_mapping_is_idempotent_for_same_canonical_report_output(self, test_db): simulation = simulation_service.create_simulation( country_id="us", population_id="household_3", @@ -106,21 +106,21 @@ def test_set_alias_is_idempotent_for_same_canonical_report_output(self, test_db) self._insert_legacy_report_output(test_db, 1001, canonical_report) assert ( - alias_service.set_alias( + id_map_service.set_mapping( legacy_report_output_id=1001, canonical_report_output_id=canonical_report["id"], ) is True ) assert ( - alias_service.set_alias( + id_map_service.set_mapping( legacy_report_output_id=1001, canonical_report_output_id=canonical_report["id"], ) is True ) - def test_rejects_alias_to_missing_canonical_report_output(self, test_db): + def test_rejects_mapping_to_missing_canonical_report_output(self, test_db): simulation = simulation_service.create_simulation( country_id="us", population_id="household_3a", @@ -136,14 +136,14 @@ def test_rejects_alias_to_missing_canonical_report_output(self, test_db): self._insert_legacy_report_output(test_db, 1002, canonical_report) with pytest.raises(ValueError) as exc_info: - alias_service.set_alias( + id_map_service.set_mapping( legacy_report_output_id=1002, canonical_report_output_id=999999, ) assert "Canonical report output #999999 not found" in str(exc_info.value) - def test_rejects_conflicting_alias_remap(self, test_db): + def test_rejects_conflicting_mapping_remap(self, test_db): simulation = simulation_service.create_simulation( country_id="us", population_id="household_4", @@ -163,23 +163,23 @@ def test_rejects_conflicting_alias_remap(self, test_db): year="2026", ) self._insert_legacy_report_output(test_db, 1003, canonical_report) - alias_service.set_alias( + id_map_service.set_mapping( legacy_report_output_id=1003, canonical_report_output_id=canonical_report["id"], ) with pytest.raises(ValueError) as exc_info: - alias_service.set_alias( + id_map_service.set_mapping( legacy_report_output_id=1003, canonical_report_output_id=other_report["id"], ) assert ( - "Legacy report output alias already points to canonical report output " + "Legacy report output ID already maps to canonical report output " f"#{canonical_report['id']}" ) in str(exc_info.value) - def test_rejects_alias_when_legacy_report_output_is_missing(self, test_db): + def test_allows_mapping_when_legacy_report_output_is_missing(self, test_db): simulation = simulation_service.create_simulation( country_id="us", population_id="household_4a", @@ -193,15 +193,22 @@ def test_rejects_alias_when_legacy_report_output_is_missing(self, test_db): year="2025", ) - with pytest.raises(ValueError) as exc_info: - alias_service.set_alias( + assert ( + id_map_service.set_mapping( legacy_report_output_id=10030, canonical_report_output_id=canonical_report["id"], ) + is True + ) - assert "Legacy report output #10030 not found" in str(exc_info.value) + resolved = id_map_service.resolve_report_output_id(10030) + assert resolved == { + "requested_report_output_id": 10030, + "canonical_report_output_id": canonical_report["id"], + "is_legacy_id": True, + } - def test_rejects_alias_when_reports_do_not_share_canonical_identity( + def test_rejects_mapping_when_reports_do_not_share_canonical_identity( self, test_db ): baseline_simulation = simulation_service.create_simulation( @@ -260,14 +267,14 @@ def test_rejects_alias_when_reports_do_not_share_canonical_identity( ) with pytest.raises(ValueError) as exc_info: - alias_service.set_alias( + id_map_service.set_mapping( legacy_report_output_id=distinct_report["id"], canonical_report_output_id=canonical_report["id"], ) assert "must share canonical report identity" in str(exc_info.value) - def test_rejects_alias_when_legacy_report_output_has_no_identity(self, test_db): + def test_rejects_mapping_when_legacy_report_output_has_no_identity(self, test_db): simulation = simulation_service.create_simulation( country_id="us", population_id="household_4b", @@ -297,14 +304,14 @@ def test_rejects_alias_when_legacy_report_output_has_no_identity(self, test_db): ) with pytest.raises(ValueError) as exc_info: - alias_service.set_alias( + id_map_service.set_mapping( legacy_report_output_id=10031, canonical_report_output_id=canonical_report["id"], ) assert "must have canonical report identity" in str(exc_info.value) - def test_rejects_alias_when_legacy_and_canonical_ids_match(self, test_db): + def test_rejects_mapping_when_legacy_and_canonical_ids_match(self, test_db): simulation = simulation_service.create_simulation( country_id="us", population_id="household_4c", @@ -319,14 +326,14 @@ def test_rejects_alias_when_legacy_and_canonical_ids_match(self, test_db): ) with pytest.raises(ValueError) as exc_info: - alias_service.set_alias( + id_map_service.set_mapping( legacy_report_output_id=canonical_report["id"], canonical_report_output_id=canonical_report["id"], ) assert "must be different" in str(exc_info.value) - def test_rejects_alias_resolution_when_canonical_report_output_is_missing( + def test_rejects_mapping_resolution_when_canonical_report_output_is_missing( self, test_db ): simulation = simulation_service.create_simulation( @@ -342,7 +349,7 @@ def test_rejects_alias_resolution_when_canonical_report_output_is_missing( year="2025", ) self._insert_legacy_report_output(test_db, 1004, canonical_report) - alias_service.set_alias( + id_map_service.set_mapping( legacy_report_output_id=1004, canonical_report_output_id=canonical_report["id"], ) @@ -352,8 +359,9 @@ def test_rejects_alias_resolution_when_canonical_report_output_is_missing( ) with pytest.raises(ValueError) as exc_info: - alias_service.resolve_canonical_report_output_id(1004) + id_map_service.resolve_canonical_report_output_id(1004) assert ( - f"Alias points to missing canonical report output #{canonical_report['id']}" + "Legacy ID mapping points to missing canonical report output " + f"#{canonical_report['id']}" ) in str(exc_info.value) diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index 42e3c88e0..312a295b8 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -3,8 +3,8 @@ from datetime import datetime, timezone from policyengine_api.constants import get_report_output_cache_version -from policyengine_api.services.report_output_alias_service import ( - ReportOutputAliasService, +from policyengine_api.services.report_output_id_map_service import ( + ReportOutputIdMapService, ) from policyengine_api.services.report_output_service import ReportOutputService from policyengine_api.services.report_run_service import ReportRunService @@ -18,7 +18,7 @@ service = ReportOutputService() report_run_service = ReportRunService() simulation_service = SimulationService() -alias_service = ReportOutputAliasService() +id_map_service = ReportOutputIdMapService() class TestReportOutputRunTimestamps: @@ -625,6 +625,111 @@ def test_create_report_output_distinguishes_explicit_economy_specs_by_identity( != stored_reports[1]["report_identity_hash"] ) + def test_create_report_output_loads_exact_inserted_row_for_explicit_spec( + self, test_db, monkeypatch + ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ks", + population_type="geography", + policy_id=39, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ks", + population_type="geography", + policy_id=40, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/ks", + "baseline_policy_id": 39, + "reform_policy_id": 40, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + + def fail_legacy_key_lookup(**_kwargs): + raise AssertionError("create should load the inserted row by primary key") + + monkeypatch.setattr( + service, + "_find_existing_report_output_row", + fail_legacy_key_lookup, + ) + + created_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + assert created_report["simulation_1_id"] == baseline_simulation["id"] + assert created_report["simulation_2_id"] == reform_simulation["id"] + assert created_report["report_identity_hash"] is not None + + def test_find_existing_for_create_validates_explicit_spec_context_before_reuse( + self, test_db + ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ia", + population_type="geography", + policy_id=36, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ia", + population_type="geography", + policy_id=37, + ) + mismatched_baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ia", + population_type="geography", + policy_id=38, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/ia", + "baseline_policy_id": 36, + "reform_policy_id": 37, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + with pytest.raises( + ValueError, match="Report spec baseline_policy_id must match" + ): + service.find_existing_report_output_for_create( + country_id="us", + simulation_1_id=mismatched_baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + ) + class TestGetReportOutput: """Test retrieving report outputs from the database.""" @@ -1359,7 +1464,7 @@ def test_get_report_output_uses_selected_display_run_for_canonical_parent( assert result["output"] == json.dumps({"budget": {"budgetary_impact": 2}}) assert result["api_version"] == get_report_output_cache_version("us") - def test_get_report_output_resolves_alias_to_canonical_parent_and_display_run( + def test_get_report_output_resolves_legacy_id_to_canonical_display_run( self, test_db ): simulation = simulation_service.create_simulation( @@ -1400,7 +1505,7 @@ def test_get_report_output_resolves_alias_to_canonical_parent_and_display_run( canonical_report["report_identity_schema_version"], ), ) - alias_service.set_alias( + id_map_service.set_mapping( legacy_report_output_id=999, canonical_report_output_id=canonical_report["id"], ) @@ -2079,7 +2184,7 @@ def fail_dual_write(tx, report_output_id, *, country_id=None, **kwargs): ).fetchall() assert rows == [] - def test_update_report_output_rolls_back_parent_update_on_dual_write_failure( + def test_update_report_output_rolls_back_parent_update_on_run_write_failure( self, test_db, monkeypatch ): simulation = simulation_service.create_simulation( @@ -2095,16 +2200,16 @@ def test_update_report_output_rolls_back_parent_update_on_dual_write_failure( year="2025", ) - def fail_dual_write(tx, report_output_id, *, country_id=None, **kwargs): - raise RuntimeError("dual write sync failed") + def fail_run_update(*args, **kwargs): + raise RuntimeError("run update failed") monkeypatch.setattr( service, - "_ensure_report_output_dual_write_state_in_transaction", - fail_dual_write, + "_update_report_run_in_transaction", + fail_run_update, ) - with pytest.raises(RuntimeError, match="dual write sync failed"): + with pytest.raises(RuntimeError, match="run update failed"): service.update_report_output( country_id="us", report_id=created_report["id"], diff --git a/tests/unit/services/test_simulation_service.py b/tests/unit/services/test_simulation_service.py index 8117488f7..fc53504dc 100644 --- a/tests/unit/services/test_simulation_service.py +++ b/tests/unit/services/test_simulation_service.py @@ -444,7 +444,7 @@ def test_update_simulation_does_not_append_extra_run_for_legacy_patch_traffic( assert runs[0]["id"] == first_run["id"] assert runs[0]["status"] == "complete" - def test_update_simulation_rolls_back_parent_update_on_dual_write_failure( + def test_update_simulation_rolls_back_parent_update_on_run_write_failure( self, test_db, monkeypatch ): created_simulation = service.create_simulation( @@ -454,16 +454,16 @@ def test_update_simulation_rolls_back_parent_update_on_dual_write_failure( policy_id=15, ) - def fail_dual_write(tx, simulation_id, *, country_id=None, **kwargs): - raise RuntimeError("dual write sync failed") + def fail_run_update(*args, **kwargs): + raise RuntimeError("run update failed") monkeypatch.setattr( service, - "_ensure_simulation_dual_write_state_in_transaction", - fail_dual_write, + "_update_simulation_run_in_transaction", + fail_run_update, ) - with pytest.raises(RuntimeError, match="dual write sync failed"): + with pytest.raises(RuntimeError, match="run update failed"): service.update_simulation( country_id="us", simulation_id=created_simulation["id"], @@ -562,6 +562,37 @@ def test_update_simulation_preserves_existing_run_metadata_without_overrides( assert run["data_version"] == "2026.04.16" assert run["runtime_app_name"] == "policyengine-app-v2" + def test_update_simulation_backfills_null_existing_run_metadata_from_parent( + self, test_db + ): + created_simulation = service.create_simulation( + country_id="us", + population_id="household_metadata_null_backfill", + population_type="household", + policy_id=17, + ) + test_db.query( + """ + UPDATE simulation_runs + SET country_package_version = NULL + WHERE simulation_id = ? + """, + (created_simulation["id"],), + ) + + service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + status="complete", + output=json.dumps({"result": "ok"}), + ) + + run = test_db.query( + "SELECT * FROM simulation_runs WHERE simulation_id = ?", + (created_simulation["id"],), + ).fetchone() + assert run["country_package_version"] == created_simulation["api_version"] + def test_update_simulation_allows_explicit_metadata_override_on_existing_run( self, test_db ): @@ -569,7 +600,7 @@ def test_update_simulation_allows_explicit_metadata_override_on_existing_run( country_id="us", population_id="household_metadata_override", population_type="household", - policy_id=17, + policy_id=18, ) service.update_simulation( diff --git a/tests/unit/test_stage5_routes.py b/tests/unit/test_stage5_routes.py index 34fac32f3..fdbe5f4b9 100644 --- a/tests/unit/test_stage5_routes.py +++ b/tests/unit/test_stage5_routes.py @@ -5,14 +5,20 @@ from policyengine_api.constants import get_report_output_cache_version from policyengine_api.routes.report_output_routes import report_output_bp from policyengine_api.routes.simulation_routes import simulation_bp +from policyengine_api.services.report_output_id_map_service import ( + ReportOutputIdMapService, +) from policyengine_api.services.report_output_service import ReportOutputService from policyengine_api.services.report_run_service import ReportRunService +from policyengine_api.services.simulation_run_service import SimulationRunService from policyengine_api.services.simulation_service import SimulationService simulation_service = SimulationService() report_output_service = ReportOutputService() report_run_service = ReportRunService() +report_output_id_map_service = ReportOutputIdMapService() +simulation_run_service = SimulationRunService() def create_test_client() -> Flask: @@ -279,6 +285,40 @@ def test_create_report_output_same_explicit_spec_returns_existing_row(test_db): ) +def test_create_report_output_same_identity_after_cache_version_change_reuses_row( + test_db, monkeypatch +): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_cache_version_reuse", + population_type="household", + policy_id=75, + ) + client = create_test_client() + payload = { + "simulation_1_id": simulation["id"], + "simulation_2_id": None, + "year": "2026", + } + + first_response = client.post("/us/report", json=payload) + monkeypatch.setattr( + "policyengine_api.services.report_output_service.get_report_output_cache_version", + lambda country_id: f"{country_id}-new-report-cache-version", + ) + second_response = client.post("/us/report", json=payload) + + assert first_response.status_code == 201 + assert second_response.status_code == 200 + assert ( + first_response.get_json()["result"]["id"] + == second_response.get_json()["result"]["id"] + ) + + rows = test_db.query("SELECT * FROM report_outputs").fetchall() + assert len(rows) == 1 + + def test_create_report_output_distinct_explicit_specs_create_distinct_rows(test_db): baseline_simulation = simulation_service.create_simulation( country_id="us", @@ -343,6 +383,70 @@ def test_create_report_output_distinct_explicit_specs_create_distinct_rows(test_ ) +def test_create_report_output_explicit_spec_validates_requested_simulations_before_reuse( + test_db, +): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ma", + population_type="geography", + policy_id=70, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ma", + population_type="geography", + policy_id=71, + ) + mismatched_baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ma", + population_type="geography", + policy_id=72, + ) + payload = { + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/ma", + "baseline_policy_id": 70, + "reform_policy_id": 71, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + } + + client = create_test_client() + create_response = client.post("/us/report", json=payload) + missing_response = client.post( + "/us/report", + json={ + **payload, + "simulation_1_id": 999999, + }, + ) + mismatched_response = client.post( + "/us/report", + json={ + **payload, + "simulation_1_id": mismatched_baseline_simulation["id"], + }, + ) + + assert create_response.status_code == 201 + assert missing_response.status_code == 400 + assert mismatched_response.status_code == 400 + + report_rows = test_db.query("SELECT * FROM report_outputs").fetchall() + assert len(report_rows) == 1 + + def test_create_report_output_missing_primary_simulation_returns_bad_request(test_db): client = create_test_client() response = client.post( @@ -467,6 +571,72 @@ def test_patch_simulation_persists_run_metadata_fields(test_db): assert run["runtime_app_name"] == "policyengine-app-v2" +def test_patch_simulation_explicit_run_id_updates_only_that_run(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_explicit_simulation_run", + population_type="household", + policy_id=78, + ) + simulation_service.update_simulation( + country_id="us", + simulation_id=simulation["id"], + status="complete", + output=json.dumps({"result": "initial"}), + ) + initial_run = test_db.query( + "SELECT * FROM simulation_runs WHERE simulation_id = ?", + (simulation["id"],), + ).fetchone() + rerun = simulation_run_service.create_simulation_run( + simulation["id"], + trigger_type="rerun", + ) + + client = create_test_client() + response = client.patch( + "/us/simulation", + json={ + "id": simulation["id"], + "simulation_run_id": rerun["id"], + "status": "complete", + "output": json.dumps({"result": "explicit rerun"}), + }, + ) + + assert response.status_code == 200 + initial_run_after = test_db.query( + "SELECT * FROM simulation_runs WHERE id = ?", + (initial_run["id"],), + ).fetchone() + rerun_after = test_db.query( + "SELECT * FROM simulation_runs WHERE id = ?", + (rerun["id"],), + ).fetchone() + assert initial_run_after["output"] == json.dumps({"result": "initial"}) + assert rerun_after["output"] == json.dumps({"result": "explicit rerun"}) + + +def test_patch_simulation_rejects_non_string_run_metadata(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_invalid_metadata", + population_type="household", + policy_id=73, + ) + + client = create_test_client() + response = client.patch( + "/us/simulation", + json={ + "id": simulation["id"], + "country_package_version": 123, + }, + ) + + assert response.status_code == 400 + + def test_get_report_output_wrong_country_returns_not_found(test_db): test_db.query( """ @@ -486,7 +656,7 @@ def test_get_report_output_wrong_country_returns_not_found(test_db): assert response.status_code == 404 -def test_get_report_output_alias_resolves_to_canonical_display_run(test_db): +def test_get_report_output_legacy_id_resolves_to_canonical_display_run(test_db): simulation = simulation_service.create_simulation( country_id="us", population_id="household_route_alias", @@ -524,7 +694,7 @@ def test_get_report_output_alias_resolves_to_canonical_display_run(test_db): ) test_db.query( """ - INSERT INTO legacy_report_output_aliases ( + INSERT INTO legacy_report_output_id_map ( legacy_report_output_id, canonical_report_output_id ) VALUES (?, ?) """, @@ -584,9 +754,7 @@ def test_get_report_output_reads_malformed_legacy_row_without_runs_or_identity( payload = response.get_json() assert payload["result"]["id"] == malformed_report["id"] assert payload["result"]["status"] == "error" - assert payload["result"]["output"] == json.dumps( - {"result": "legacy-malformed"} - ) + assert payload["result"]["output"] == json.dumps({"result": "legacy-malformed"}) assert payload["result"]["api_version"] == "r0legacy-malformed" @@ -895,6 +1063,124 @@ def test_patch_report_output_complete_promotes_active_rerun_route_path(test_db): assert stored_report["latest_successful_run_id"] == rerun["id"] +def test_patch_report_output_explicit_run_id_updates_only_that_run(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_explicit_report_run", + population_type="household", + policy_id=76, + ) + report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + report_output_service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"result": "initial"}), + ) + initial_run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report["id"],), + ).fetchone() + rerun = report_run_service.create_report_output_run( + report["id"], trigger_type="rerun" + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": report["id"], + "report_output_run_id": rerun["id"], + "status": "complete", + "output": json.dumps({"result": "explicit rerun"}), + }, + ) + + assert response.status_code == 200 + initial_run_after = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (initial_run["id"],), + ).fetchone() + rerun_after = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (rerun["id"],), + ).fetchone() + assert initial_run_after["output"] == json.dumps({"result": "initial"}) + assert rerun_after["output"] == json.dumps({"result": "explicit rerun"}) + + +def test_create_report_rerun_via_legacy_id_creates_canonical_linked_runs(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_legacy_rerun", + population_type="household", + policy_id=77, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + report_output_service.update_report_output( + country_id="us", + report_id=canonical_report["id"], + status="complete", + output=json.dumps({"result": "canonical"}), + ) + test_db.query( + """ + INSERT INTO report_outputs ( + id, country_id, simulation_1_id, simulation_2_id, api_version, status, year, + report_identity_hash, report_identity_schema_version + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + 3001, + "us", + simulation["id"], + None, + "legacy-report-cache-version", + "complete", + "2026", + canonical_report["report_identity_hash"], + canonical_report["report_identity_schema_version"], + ), + ) + report_output_id_map_service.set_mapping( + legacy_report_output_id=3001, + canonical_report_output_id=canonical_report["id"], + ) + + client = create_test_client() + response = client.post("/us/report/3001/rerun", json={}) + + assert response.status_code == 201 + result = response.get_json()["result"] + assert result["requested_report_output_id"] == 3001 + assert result["report_output_id"] == canonical_report["id"] + assert len(result["simulation_run_ids"]) == 1 + + report_run = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (result["report_output_run_id"],), + ).fetchone() + assert report_run["report_output_id"] == canonical_report["id"] + assert report_run["trigger_type"] == "rerun" + + simulation_run = test_db.query( + "SELECT * FROM simulation_runs WHERE id = ?", + (result["simulation_run_ids"][0],), + ).fetchone() + assert simulation_run["report_output_run_id"] == result["report_output_run_id"] + assert simulation_run["input_position"] == 1 + + def test_patch_report_output_persists_run_metadata_fields(test_db): simulation = simulation_service.create_simulation( country_id="us", @@ -936,6 +1222,32 @@ def test_patch_report_output_persists_run_metadata_fields(test_db): assert run["resolved_dataset"] == "enhanced_us_household" +def test_patch_report_output_rejects_non_string_run_metadata(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/mt", + population_type="geography", + policy_id=74, + ) + report_output = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": report_output["id"], + "policyengine_version": 123, + }, + ) + + assert response.status_code == 400 + + def test_patch_report_output_preserves_stored_explicit_report_spec(test_db): baseline_simulation = simulation_service.create_simulation( country_id="us", From c6281b526660717e376955e5846e2425ee8cd97f Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 16:40:22 +0200 Subject: [PATCH 11/17] Add report rerun route coverage --- .../unit/services/test_report_spec_service.py | 36 +-- tests/unit/test_stage5_routes.py | 299 ++++++++++++++++++ 2 files changed, 317 insertions(+), 18 deletions(-) diff --git a/tests/unit/services/test_report_spec_service.py b/tests/unit/services/test_report_spec_service.py index 0dd98db86..27d0a22fa 100644 --- a/tests/unit/services/test_report_spec_service.py +++ b/tests/unit/services/test_report_spec_service.py @@ -90,12 +90,12 @@ def test_raises_for_mixed_population_types(self, test_db): population_type="geography", policy_id=2, ) - report_output = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation_1["id"], - simulation_2_id=simulation_2["id"], - year="2025", - ) + report_output = { + "country_id": "us", + "simulation_1_id": simulation_1["id"], + "simulation_2_id": simulation_2["id"], + "year": "2025", + } with pytest.raises(ValueError) as exc_info: report_spec_service.build_report_spec( @@ -117,12 +117,12 @@ def test_raises_for_mismatched_household_ids(self, test_db): population_type="household", policy_id=2, ) - report_output = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation_1["id"], - simulation_2_id=simulation_2["id"], - year="2025", - ) + report_output = { + "country_id": "us", + "simulation_1_id": simulation_1["id"], + "simulation_2_id": simulation_2["id"], + "year": "2025", + } with pytest.raises(ValueError) as exc_info: report_spec_service.build_report_spec( @@ -144,12 +144,12 @@ def test_raises_for_mismatched_geography_ids(self, test_db): population_type="geography", policy_id=11, ) - report_output = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation_1["id"], - simulation_2_id=simulation_2["id"], - year="2027", - ) + report_output = { + "country_id": "us", + "simulation_1_id": simulation_1["id"], + "simulation_2_id": simulation_2["id"], + "year": "2027", + } with pytest.raises(ValueError) as exc_info: report_spec_service.build_report_spec( diff --git a/tests/unit/test_stage5_routes.py b/tests/unit/test_stage5_routes.py index fdbe5f4b9..99636a242 100644 --- a/tests/unit/test_stage5_routes.py +++ b/tests/unit/test_stage5_routes.py @@ -1114,6 +1114,137 @@ def test_patch_report_output_explicit_run_id_updates_only_that_run(test_db): assert rerun_after["output"] == json.dumps({"result": "explicit rerun"}) +def test_patch_report_output_explicit_run_id_through_legacy_id_updates_canonical_run( + test_db, +): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_legacy_explicit_report_run", + population_type="household", + policy_id=79, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + report_output_service.update_report_output( + country_id="us", + report_id=canonical_report["id"], + status="complete", + output=json.dumps({"result": "initial"}), + ) + initial_run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (canonical_report["id"],), + ).fetchone() + rerun = report_run_service.create_report_output_run( + canonical_report["id"], trigger_type="rerun" + ) + test_db.query( + """ + INSERT INTO report_outputs ( + id, country_id, simulation_1_id, simulation_2_id, api_version, status, year, + report_identity_hash, report_identity_schema_version + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + 3002, + "us", + simulation["id"], + None, + "legacy-report-cache-version", + "complete", + "2026", + canonical_report["report_identity_hash"], + canonical_report["report_identity_schema_version"], + ), + ) + report_output_id_map_service.set_mapping( + legacy_report_output_id=3002, + canonical_report_output_id=canonical_report["id"], + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": 3002, + "report_output_run_id": rerun["id"], + "status": "complete", + "output": json.dumps({"result": "legacy explicit rerun"}), + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["id"] == 3002 + + initial_run_after = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (initial_run["id"],), + ).fetchone() + rerun_after = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (rerun["id"],), + ).fetchone() + assert initial_run_after["output"] == json.dumps({"result": "initial"}) + assert rerun_after["report_output_id"] == canonical_report["id"] + assert rerun_after["output"] == json.dumps({"result": "legacy explicit rerun"}) + + +def test_create_report_rerun_via_canonical_id_creates_canonical_linked_runs(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_canonical_rerun", + population_type="household", + policy_id=80, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + report_output_service.update_report_output( + country_id="us", + report_id=canonical_report["id"], + status="complete", + output=json.dumps({"result": "canonical"}), + ) + + client = create_test_client() + response = client.post(f"/us/report/{canonical_report['id']}/rerun", json={}) + + assert response.status_code == 201 + result = response.get_json()["result"] + assert result["requested_report_output_id"] == canonical_report["id"] + assert result["report_output_id"] == canonical_report["id"] + assert len(result["simulation_run_ids"]) == 1 + + report_runs = test_db.query( + """ + SELECT * FROM report_output_runs + WHERE report_output_id = ? + ORDER BY run_sequence + """, + (canonical_report["id"],), + ).fetchall() + assert len(report_runs) == 2 + assert report_runs[0]["trigger_type"] == "initial" + assert report_runs[1]["id"] == result["report_output_run_id"] + assert report_runs[1]["trigger_type"] == "rerun" + assert report_runs[1]["status"] == "pending" + + simulation_run = test_db.query( + "SELECT * FROM simulation_runs WHERE id = ?", + (result["simulation_run_ids"][0],), + ).fetchone() + assert simulation_run["report_output_run_id"] == result["report_output_run_id"] + assert simulation_run["input_position"] == 1 + + def test_create_report_rerun_via_legacy_id_creates_canonical_linked_runs(test_db): simulation = simulation_service.create_simulation( country_id="us", @@ -1181,6 +1312,174 @@ def test_create_report_rerun_via_legacy_id_creates_canonical_linked_runs(test_db assert simulation_run["input_position"] == 1 +def test_create_report_rerun_for_comparison_report_creates_two_linked_simulation_runs( + test_db, +): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/nc", + population_type="geography", + policy_id=81, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/nc", + population_type="geography", + policy_id=82, + ) + report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + ) + report_output_service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"result": "comparison"}), + ) + + client = create_test_client() + response = client.post(f"/us/report/{report['id']}/rerun", json={}) + + assert response.status_code == 201 + result = response.get_json()["result"] + assert result["report_output_id"] == report["id"] + assert len(result["simulation_run_ids"]) == 2 + + linked_simulation_runs = test_db.query( + """ + SELECT * FROM simulation_runs + WHERE report_output_run_id = ? + ORDER BY input_position + """, + (result["report_output_run_id"],), + ).fetchall() + assert [run["simulation_id"] for run in linked_simulation_runs] == [ + baseline_simulation["id"], + reform_simulation["id"], + ] + assert [run["input_position"] for run in linked_simulation_runs] == [1, 2] + assert [run["status"] for run in linked_simulation_runs] == [ + "pending", + "pending", + ] + + +def test_create_report_rerun_rejects_report_with_missing_linked_simulation(test_db): + test_db.query( + """ + INSERT INTO report_outputs ( + country_id, simulation_1_id, simulation_2_id, api_version, status, year + ) VALUES (?, ?, ?, ?, ?, ?) + """, + ("us", 987654, None, get_report_output_cache_version("us"), "complete", "2026"), + ) + report = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + client = create_test_client() + response = client.post(f"/us/report/{report['id']}/rerun", json={}) + + assert response.status_code == 400 + assert "Simulation #987654 not found" in response.get_data(as_text=True) + + report_runs = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report["id"],), + ).fetchall() + assert report_runs == [] + + +def test_report_rerun_http_lifecycle_patches_linked_runs_and_reads_display( + test_db, +): + client = create_test_client() + simulation_response = client.post( + "/us/simulation", + json={ + "population_id": "household_route_http_lifecycle", + "population_type": "household", + "policy_id": 83, + }, + ) + assert simulation_response.status_code == 201 + simulation = simulation_response.get_json()["result"] + + report_response = client.post( + "/us/report", + json={ + "simulation_1_id": simulation["id"], + "simulation_2_id": None, + "year": "2026", + }, + ) + assert report_response.status_code == 201 + report = report_response.get_json()["result"] + + initial_patch_response = client.patch( + "/us/report", + json={ + "id": report["id"], + "status": "complete", + "output": json.dumps({"result": "initial report"}), + }, + ) + assert initial_patch_response.status_code == 200 + + rerun_response = client.post(f"/us/report/{report['id']}/rerun", json={}) + assert rerun_response.status_code == 201 + rerun = rerun_response.get_json()["result"] + assert len(rerun["simulation_run_ids"]) == 1 + + simulation_patch_response = client.patch( + "/us/simulation", + json={ + "id": simulation["id"], + "simulation_run_id": rerun["simulation_run_ids"][0], + "status": "complete", + "output": json.dumps({"result": "rerun simulation"}), + }, + ) + assert simulation_patch_response.status_code == 200 + + report_patch_response = client.patch( + "/us/report", + json={ + "id": report["id"], + "report_output_run_id": rerun["report_output_run_id"], + "status": "complete", + "output": json.dumps({"result": "rerun report"}), + }, + ) + assert report_patch_response.status_code == 200 + + get_response = client.get(f"/us/report/{report['id']}") + assert get_response.status_code == 200 + result = get_response.get_json()["result"] + assert result["id"] == report["id"] + assert result["status"] == "complete" + assert result["output"] == json.dumps({"result": "rerun report"}) + + report_rows = test_db.query("SELECT * FROM report_outputs").fetchall() + report_runs = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report["id"],), + ).fetchall() + linked_simulation_runs = test_db.query( + """ + SELECT * FROM simulation_runs + WHERE report_output_run_id = ? + """, + (rerun["report_output_run_id"],), + ).fetchall() + assert len(report_rows) == 1 + assert len(report_runs) == 2 + assert len(linked_simulation_runs) == 1 + + def test_patch_report_output_persists_run_metadata_fields(test_db): simulation = simulation_service.create_simulation( country_id="us", From 33c36a3a2e06b715507e94ae9a85639315e431cf Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 19:43:53 +0200 Subject: [PATCH 12/17] Refactor report rerun orchestration services --- .../services/report_output_service.py | 279 ++---------------- .../services/report_run_service.py | 141 +++++---- .../services/simulation_run_service.py | 147 +++++---- .../services/simulation_service.py | 65 ++++ 4 files changed, 272 insertions(+), 360 deletions(-) diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index 246ec764c..1372cdb13 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -1,4 +1,3 @@ -import uuid from datetime import datetime, timezone from sqlalchemy.engine.row import Row @@ -640,205 +639,18 @@ def _insert_bootstrap_report_run( report_spec: ReportSpec | None, version_manifest: dict[str, str | None], ) -> None: - requested_at = self._utc_timestamp() - is_terminal = report_output["status"] in ("complete", "error") - has_started = report_output["status"] in ("running", "complete", "error") - started_at = requested_at if has_started else None - finished_at = requested_at if is_terminal else None - - tx.query( - """ - INSERT INTO report_output_runs ( - id, report_output_id, run_sequence, status, output, error_message, - trigger_type, requested_at, started_at, finished_at, source_run_id, - report_spec_snapshot_json, country_package_version, policyengine_version, - data_version, runtime_app_name, report_cache_version, - simulation_cache_version, requested_version_override, resolved_dataset, - resolved_options_hash - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - str(uuid.uuid4()), - report_output["id"], - 1, - report_output["status"], - serialize_json_field(report_output.get("output")), - report_output.get("error_message"), - "initial", - requested_at, - started_at, - finished_at, - None, - (report_spec.model_dump_json() if report_spec is not None else None), - version_manifest["country_package_version"], - version_manifest["policyengine_version"], - version_manifest["data_version"], - version_manifest["runtime_app_name"], - version_manifest["report_cache_version"], - version_manifest["simulation_cache_version"], - version_manifest["requested_version_override"], - version_manifest["resolved_dataset"], - version_manifest["resolved_options_hash"], - ), - ) - - def _insert_report_run_in_transaction( - self, - tx, - report_output: dict, - *, - status: str, - trigger_type: str, - source_run_id: str | None, - report_spec: ReportSpec | None, - version_manifest: dict[str, str | None], - ) -> str: - run_sequence_row: Row | None = tx.query( - """ - SELECT COALESCE(MAX(run_sequence), 0) AS max_run_sequence - FROM report_output_runs - WHERE report_output_id = ? - """, - (report_output["id"],), - ).fetchone() - run_sequence = ( - int(run_sequence_row["max_run_sequence"]) + 1 - if run_sequence_row is not None - else 1 - ) - run_id = str(uuid.uuid4()) - requested_at = self._utc_timestamp() - is_terminal = status in ("complete", "error") - has_started = status in ("running", "complete", "error") - started_at = requested_at if has_started else None - finished_at = requested_at if is_terminal else None - - tx.query( - """ - INSERT INTO report_output_runs ( - id, report_output_id, run_sequence, status, output, error_message, - trigger_type, requested_at, started_at, finished_at, source_run_id, - report_spec_snapshot_json, country_package_version, policyengine_version, - data_version, runtime_app_name, report_cache_version, - simulation_cache_version, requested_version_override, resolved_dataset, - resolved_options_hash - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - run_id, - report_output["id"], - run_sequence, - status, - None, - None, - trigger_type, - requested_at, - started_at, - finished_at, - source_run_id, - (report_spec.model_dump_json() if report_spec is not None else None), - version_manifest["country_package_version"], - version_manifest["policyengine_version"], - version_manifest["data_version"], - version_manifest["runtime_app_name"], - version_manifest["report_cache_version"], - version_manifest["simulation_cache_version"], - version_manifest["requested_version_override"], - version_manifest["resolved_dataset"], - version_manifest["resolved_options_hash"], - ), - ) - return run_id - - def _select_simulation_display_run( - self, - simulation: dict, - runs_descending: list[dict], - ) -> dict | None: - active_run_id = simulation.get("active_run_id") - if active_run_id is not None: - for run in runs_descending: - if run["id"] == active_run_id: - return run - - latest_successful_run_id = simulation.get("latest_successful_run_id") - if latest_successful_run_id is not None: - for run in runs_descending: - if run["id"] == latest_successful_run_id: - return run - - return runs_descending[0] if runs_descending else None - - def _insert_simulation_run_in_transaction( - self, - tx, - simulation: dict, - *, - report_output_run_id: str, - input_position: int, - source_run: dict | None, - ) -> str: - simulation_spec = ( - self.simulation_service._upsert_simulation_spec_in_transaction( - tx, - simulation, - ) - ) - version_manifest = ( - self.simulation_service._build_existing_run_version_manifest( - source_run, - simulation, - ) - if source_run is not None - else self.simulation_service._build_bootstrap_version_manifest(simulation) - ) - run_sequence_row: Row | None = tx.query( - """ - SELECT COALESCE(MAX(run_sequence), 0) AS max_run_sequence - FROM simulation_runs - WHERE simulation_id = ? - """, - (simulation["id"],), - ).fetchone() - run_sequence = ( - int(run_sequence_row["max_run_sequence"]) + 1 - if run_sequence_row is not None - else 1 - ) - run_id = str(uuid.uuid4()) - tx.query( - """ - INSERT INTO simulation_runs ( - id, simulation_id, report_output_run_id, input_position, run_sequence, - status, output, error_message, trigger_type, requested_at, started_at, - finished_at, source_run_id, simulation_spec_snapshot_json, - country_package_version, policyengine_version, data_version, - runtime_app_name, simulation_cache_version - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - run_id, - simulation["id"], - report_output_run_id, - input_position, - run_sequence, - "pending", - None, - None, - "report_rerun", - self._utc_timestamp(), - None, - None, - source_run["id"] if source_run is not None else None, - simulation_spec.model_dump_json(), - version_manifest["country_package_version"], - version_manifest["policyengine_version"], - version_manifest["data_version"], - version_manifest["runtime_app_name"], - version_manifest["simulation_cache_version"], + self.report_run_service.create_report_output_run_in_transaction( + tx, + report_output["id"], + status=report_output["status"], + trigger_type="initial", + output=report_output.get("output"), + error_message=report_output.get("error_message"), + report_spec_snapshot=( + report_spec.model_dump() if report_spec is not None else None ), + version_manifest=version_manifest, ) - return run_id def _update_report_run_in_transaction( self, @@ -1752,17 +1564,24 @@ def tx_callback(tx): version_manifest_overrides=version_manifest_overrides, ) ) - report_run_id = self._insert_report_run_in_transaction( - tx, - canonical_report, - status="pending", - trigger_type="rerun", - source_run_id=( - source_report_run["id"] if source_report_run is not None else None - ), - report_spec=report_spec, - version_manifest=report_version_manifest, + report_run = ( + self.report_run_service.create_report_output_run_in_transaction( + tx, + canonical_report_id, + status="pending", + trigger_type="rerun", + source_run_id=( + source_report_run["id"] + if source_report_run is not None + else None + ), + report_spec_snapshot=( + report_spec.model_dump() if report_spec is not None else None + ), + version_manifest=report_version_manifest, + ) ) + report_run_id = report_run["id"] simulation_run_ids: list[str] = [] for input_position, simulation in ( @@ -1772,53 +1591,13 @@ def tx_callback(tx): if simulation is None: continue - simulation_runs_descending = ( - self.simulation_service._list_simulation_runs_descending( - simulation["id"], - queryer=tx, - ) - ) - source_simulation_run = self._select_simulation_display_run( - simulation, - simulation_runs_descending, - ) - simulation_run_id = self._insert_simulation_run_in_transaction( + simulation_run = self.simulation_service.create_report_rerun_simulation_run_in_transaction( tx, simulation, report_output_run_id=report_run_id, input_position=input_position, - source_run=source_simulation_run, - ) - simulation_run_ids.append(simulation_run_id) - - simulation["status"] = "pending" - simulation["output"] = None - simulation["error_message"] = None - simulation_runs_descending = ( - self.simulation_service._list_simulation_runs_descending( - simulation["id"], - queryer=tx, - ) - ) - self.simulation_service._sync_parent_pointers_in_transaction( - tx, - simulation, - simulation_runs_descending, - ) - tx.query( - """ - UPDATE simulations - SET status = ?, output = ?, error_message = ? - WHERE id = ? AND country_id = ? - """, - ( - "pending", - None, - None, - simulation["id"], - country_id, - ), ) + simulation_run_ids.append(simulation_run["id"]) canonical_report["status"] = "pending" canonical_report["output"] = None diff --git a/policyengine_api/services/report_run_service.py b/policyengine_api/services/report_run_service.py index 9899f6cc9..74fc61e96 100644 --- a/policyengine_api/services/report_run_service.py +++ b/policyengine_api/services/report_run_service.py @@ -55,69 +55,98 @@ def create_report_output_run( report_spec_snapshot: dict[str, Any] | str | None = None, version_manifest: dict[str, str | None] | None = None, run_id: str | None = None, + ) -> dict: + def create_run_transaction(tx) -> dict: + return self.create_report_output_run_in_transaction( + tx, + report_output_id, + status=status, + trigger_type=trigger_type, + output=output, + error_message=error_message, + source_run_id=source_run_id, + report_spec_snapshot=report_spec_snapshot, + version_manifest=version_manifest, + run_id=run_id, + ) + + return database.transaction(create_run_transaction) + + def create_report_output_run_in_transaction( + self, + tx, + report_output_id: int, + status: str = "pending", + trigger_type: str = "initial", + output: dict[str, Any] | list[Any] | str | None = None, + error_message: str | None = None, + source_run_id: str | None = None, + report_spec_snapshot: dict[str, Any] | str | None = None, + version_manifest: dict[str, str | None] | None = None, + run_id: str | None = None, ) -> dict: run_id = run_id or str(uuid.uuid4()) version_manifest = version_manifest or {} lock_clause = "" if database.local else " FOR UPDATE" - def create_run_transaction(tx) -> None: - parent_row: Row | None = tx.query( - f"SELECT id FROM report_outputs WHERE id = ?{lock_clause}", - (report_output_id,), - ).fetchone() - if parent_row is None: - raise ValueError(f"Report output #{report_output_id} not found") - - run_sequence_row: Row | None = tx.query( - """ - SELECT COALESCE(MAX(run_sequence), 0) AS max_run_sequence - FROM report_output_runs - WHERE report_output_id = ? - """, - (report_output_id,), - ).fetchone() - run_sequence = ( - int(run_sequence_row["max_run_sequence"]) + 1 - if run_sequence_row is not None - else 1 - ) + parent_row: Row | None = tx.query( + f"SELECT id FROM report_outputs WHERE id = ?{lock_clause}", + (report_output_id,), + ).fetchone() + if parent_row is None: + raise ValueError(f"Report output #{report_output_id} not found") - requested_at = self._utc_timestamp() - is_terminal = status in ("complete", "error") - has_started = status in ("running", "complete", "error") - started_at = requested_at if has_started else None - finished_at = requested_at if is_terminal else None - - tx.query( - f""" - INSERT INTO report_output_runs ( - id, report_output_id, run_sequence, status, output, error_message, - trigger_type, requested_at, started_at, finished_at, source_run_id, - report_spec_snapshot_json, {", ".join(REPORT_RUN_VERSION_FIELDS)} - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - run_id, - report_output_id, - run_sequence, - status, - self._serialize_json(output), - error_message, - trigger_type, - requested_at, - started_at, - finished_at, - source_run_id, - self._serialize_json(report_spec_snapshot), - *[ - version_manifest.get(field) - for field in REPORT_RUN_VERSION_FIELDS - ], - ), - ) + run_sequence_row: Row | None = tx.query( + """ + SELECT COALESCE(MAX(run_sequence), 0) AS max_run_sequence + FROM report_output_runs + WHERE report_output_id = ? + """, + (report_output_id,), + ).fetchone() + run_sequence = ( + int(run_sequence_row["max_run_sequence"]) + 1 + if run_sequence_row is not None + else 1 + ) - database.transaction(create_run_transaction) - return self.get_report_output_run(run_id) + requested_at = self._utc_timestamp() + is_terminal = status in ("complete", "error") + has_started = status in ("running", "complete", "error") + started_at = requested_at if has_started else None + finished_at = requested_at if is_terminal else None + + tx.query( + f""" + INSERT INTO report_output_runs ( + id, report_output_id, run_sequence, status, output, error_message, + trigger_type, requested_at, started_at, finished_at, source_run_id, + report_spec_snapshot_json, {", ".join(REPORT_RUN_VERSION_FIELDS)} + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + run_id, + report_output_id, + run_sequence, + status, + self._serialize_json(output), + error_message, + trigger_type, + requested_at, + started_at, + finished_at, + source_run_id, + self._serialize_json(report_spec_snapshot), + *[version_manifest.get(field) for field in REPORT_RUN_VERSION_FIELDS], + ), + ) + created_row: Row | None = tx.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (run_id,), + ).fetchone() + if created_row is None: + raise ValueError(f"Report output run #{run_id} not found after create") + return self._parse_run_row(created_row) def get_report_output_run(self, run_id: str) -> dict | None: row: Row | None = database.query( diff --git a/policyengine_api/services/simulation_run_service.py b/policyengine_api/services/simulation_run_service.py index 544aca9c2..5123e61e2 100644 --- a/policyengine_api/services/simulation_run_service.py +++ b/policyengine_api/services/simulation_run_service.py @@ -48,66 +48,105 @@ def create_simulation_run( simulation_spec_snapshot: dict[str, Any] | str | None = None, version_manifest: dict[str, str | None] | None = None, run_id: str | None = None, + ) -> dict: + def create_run_transaction(tx) -> dict: + return self.create_simulation_run_in_transaction( + tx, + simulation_id, + report_output_run_id=report_output_run_id, + input_position=input_position, + status=status, + trigger_type=trigger_type, + output=output, + error_message=error_message, + source_run_id=source_run_id, + simulation_spec_snapshot=simulation_spec_snapshot, + version_manifest=version_manifest, + run_id=run_id, + ) + + return database.transaction(create_run_transaction) + + def create_simulation_run_in_transaction( + self, + tx, + simulation_id: int, + report_output_run_id: str | None = None, + input_position: int | None = None, + status: str = "pending", + trigger_type: str = "initial", + output: dict[str, Any] | list[Any] | str | None = None, + error_message: str | None = None, + source_run_id: str | None = None, + simulation_spec_snapshot: dict[str, Any] | str | None = None, + version_manifest: dict[str, str | None] | None = None, + run_id: str | None = None, + requested_at: str | None = None, + started_at: str | None = None, + finished_at: str | None = None, ) -> dict: run_id = run_id or str(uuid.uuid4()) version_manifest = version_manifest or {} lock_clause = "" if database.local else " FOR UPDATE" - def create_run_transaction(tx) -> None: - parent_row: Row | None = tx.query( - f"SELECT id FROM simulations WHERE id = ?{lock_clause}", - (simulation_id,), - ).fetchone() - if parent_row is None: - raise ValueError(f"Simulation #{simulation_id} not found") - - run_sequence_row: Row | None = tx.query( - """ - SELECT COALESCE(MAX(run_sequence), 0) AS max_run_sequence - FROM simulation_runs - WHERE simulation_id = ? - """, - (simulation_id,), - ).fetchone() - run_sequence = ( - int(run_sequence_row["max_run_sequence"]) + 1 - if run_sequence_row is not None - else 1 - ) - - tx.query( - f""" - INSERT INTO simulation_runs ( - id, simulation_id, report_output_run_id, input_position, run_sequence, - status, output, error_message, trigger_type, requested_at, started_at, - finished_at, source_run_id, simulation_spec_snapshot_json, - {", ".join(SIMULATION_RUN_VERSION_FIELDS)} - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - run_id, - simulation_id, - report_output_run_id, - input_position, - run_sequence, - status, - self._serialize_json(output), - error_message, - trigger_type, - None, - None, - None, - source_run_id, - self._serialize_json(simulation_spec_snapshot), - *[ - version_manifest.get(field) - for field in SIMULATION_RUN_VERSION_FIELDS - ], - ), - ) + parent_row: Row | None = tx.query( + f"SELECT id FROM simulations WHERE id = ?{lock_clause}", + (simulation_id,), + ).fetchone() + if parent_row is None: + raise ValueError(f"Simulation #{simulation_id} not found") - database.transaction(create_run_transaction) - return self.get_simulation_run(run_id) + run_sequence_row: Row | None = tx.query( + """ + SELECT COALESCE(MAX(run_sequence), 0) AS max_run_sequence + FROM simulation_runs + WHERE simulation_id = ? + """, + (simulation_id,), + ).fetchone() + run_sequence = ( + int(run_sequence_row["max_run_sequence"]) + 1 + if run_sequence_row is not None + else 1 + ) + + tx.query( + f""" + INSERT INTO simulation_runs ( + id, simulation_id, report_output_run_id, input_position, run_sequence, + status, output, error_message, trigger_type, requested_at, started_at, + finished_at, source_run_id, simulation_spec_snapshot_json, + {", ".join(SIMULATION_RUN_VERSION_FIELDS)} + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + run_id, + simulation_id, + report_output_run_id, + input_position, + run_sequence, + status, + self._serialize_json(output), + error_message, + trigger_type, + requested_at, + started_at, + finished_at, + source_run_id, + self._serialize_json(simulation_spec_snapshot), + *[ + version_manifest.get(field) + for field in SIMULATION_RUN_VERSION_FIELDS + ], + ), + ) + created_row: Row | None = tx.query( + "SELECT * FROM simulation_runs WHERE id = ?", + (run_id,), + ).fetchone() + if created_row is None: + raise ValueError(f"Simulation run #{run_id} not found after create") + return self._parse_run_row(created_row) def get_simulation_run(self, run_id: str) -> dict | None: row: Row | None = database.query( diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index 35ef7a110..3fb6cca2c 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -1,4 +1,5 @@ import uuid +from datetime import datetime, timezone from sqlalchemy.engine.row import Row @@ -9,6 +10,7 @@ parse_json_field, serialize_json_field, ) +from policyengine_api.services.simulation_run_service import SimulationRunService from policyengine_api.services.simulation_spec_service import ( SimulationSpec, SimulationSpecService, @@ -18,10 +20,14 @@ class SimulationService: def __init__(self): self.simulation_spec_service = SimulationSpecService() + self.simulation_run_service = SimulationRunService() def _lock_clause(self) -> str: return "" if database.local else " FOR UPDATE" + def _utc_timestamp(self) -> str: + return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + def _get_simulation_row( self, simulation_id: int, @@ -377,6 +383,65 @@ def _sync_parent_mirror_from_display_run_in_transaction( raise ValueError(f"Simulation #{simulation['id']} not found after sync") return refreshed_simulation + def create_report_rerun_simulation_run_in_transaction( + self, + tx, + simulation: dict, + *, + report_output_run_id: str, + input_position: int, + ) -> dict: + simulation_spec = self._upsert_simulation_spec_in_transaction(tx, simulation) + runs_descending = self._list_simulation_runs_descending( + simulation["id"], + queryer=tx, + ) + source_run = self._select_display_run(simulation, runs_descending) + version_manifest = ( + self._build_existing_run_version_manifest( + source_run, + simulation, + ) + if source_run is not None + else self._build_bootstrap_version_manifest(simulation) + ) + created_run = self.simulation_run_service.create_simulation_run_in_transaction( + tx, + simulation["id"], + report_output_run_id=report_output_run_id, + input_position=input_position, + status="pending", + trigger_type="report_rerun", + source_run_id=source_run["id"] if source_run is not None else None, + simulation_spec_snapshot=simulation_spec.model_dump(), + version_manifest=version_manifest, + requested_at=self._utc_timestamp(), + ) + + simulation["status"] = "pending" + simulation["output"] = None + simulation["error_message"] = None + runs_descending = self._list_simulation_runs_descending( + simulation["id"], + queryer=tx, + ) + self._sync_parent_pointers_in_transaction(tx, simulation, runs_descending) + tx.query( + """ + UPDATE simulations + SET status = ?, output = ?, error_message = ? + WHERE id = ? AND country_id = ? + """, + ( + "pending", + None, + None, + simulation["id"], + simulation["country_id"], + ), + ) + return created_run + def _merge_display_run_into_simulation( self, simulation: dict, From edb771d94b75805f59f7f8201f0217164a33649b Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 20:37:15 +0200 Subject: [PATCH 13/17] Add Alembic migration scaffold --- .claude/skills/database-migrations.md | 35 +++ alembic.ini | 48 ++++ alembic/README | 8 + alembic/env.py | 83 ++++++ alembic/script.py.mako | 28 ++ ...0511_60d38593ddc3_initial_legacy_schema.py | 264 ++++++++++++++++++ ...c54ce35_add_report_run_canonical_schema.py | 72 +++++ policyengine_api/data/alembic_metadata.py | 56 ++++ policyengine_api/data/initialise.sql | 5 + policyengine_api/data/initialise_local.sql | 6 + pyproject.toml | 1 + tests/unit/data/test_run_schema.py | 7 + uv.lock | 28 ++ 13 files changed, 641 insertions(+) create mode 100644 .claude/skills/database-migrations.md create mode 100644 alembic.ini create mode 100644 alembic/README create mode 100644 alembic/env.py create mode 100644 alembic/script.py.mako create mode 100644 alembic/versions/20260511_60d38593ddc3_initial_legacy_schema.py create mode 100644 alembic/versions/20260511_d39d9c54ce35_add_report_run_canonical_schema.py create mode 100644 policyengine_api/data/alembic_metadata.py diff --git a/.claude/skills/database-migrations.md b/.claude/skills/database-migrations.md new file mode 100644 index 000000000..d94ce61dc --- /dev/null +++ b/.claude/skills/database-migrations.md @@ -0,0 +1,35 @@ +# Database Migration Guidelines + +## Overview + +This project uses Alembic for database migrations. API v1 still uses raw SQL +initializers rather than ORM models, so Alembic target metadata is reflected +from `policyengine_api/data/initialise_local.sql` by default. + +## Rules + +- Do not manually author Alembic operations for normal schema changes. +- Generate migrations with `uv run alembic revision --autogenerate`. +- Review generated migrations before applying them. +- Keep SQL initializers and generated migrations aligned. +- For pre-existing production databases, stamp the base revision before applying + new upgrade revisions. + +## Commands + +```bash +uv run alembic revision --autogenerate -m "Description" +uv run alembic upgrade head +uv run alembic current +uv run alembic history +uv run alembic stamp +``` + +## API v1 Notes + +- Set `POLICYENGINE_ALEMBIC_DATABASE_URL` to the database SQLAlchemy URL Alembic + should connect to. +- Set `POLICYENGINE_ALEMBIC_SCHEMA_SQL` when generating against a temporary + schema SQL file instead of the current initializer. +- The base migration should be stamped in production because the tables already + exist there. diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 000000000..1404bb8e6 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,48 @@ +# Alembic configuration for PolicyEngine API v1. + +[alembic] +script_location = %(here)s/alembic +file_template = %%(year)d%%(month).2d%%(day).2d_%%(rev)s_%%(slug)s +prepend_sys_path = . +path_separator = os +output_encoding = utf-8 + +# Overridden by alembic/env.py. For local generation, set +# POLICYENGINE_ALEMBIC_DATABASE_URL explicitly. +sqlalchemy.url = sqlite:///policyengine_api/data/policyengine.db + +[post_write_hooks] + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 000000000..6d8c5892c --- /dev/null +++ b/alembic/README @@ -0,0 +1,8 @@ +PolicyEngine API v1 Alembic migrations. + +This project does not currently use SQLAlchemy ORM models. Alembic +autogenerate reflects target metadata from `policyengine_api/data/initialise_local.sql` +or from the path in `POLICYENGINE_ALEMBIC_SCHEMA_SQL`. + +Use `alembic stamp` for pre-existing production databases before applying +incremental migrations. diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 000000000..620ffe8f4 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,83 @@ +"""Alembic environment for PolicyEngine API v1 raw-SQL schema migrations.""" + +from logging.config import fileConfig +import importlib.util +import os +from pathlib import Path +import sys + +from sqlalchemy import engine_from_config, pool + +from alembic import context + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +metadata_path = ( + Path(__file__).parent.parent / "policyengine_api" / "data" / "alembic_metadata.py" +) +metadata_spec = importlib.util.spec_from_file_location( + "policyengine_api_alembic_metadata", + metadata_path, +) +if metadata_spec is None or metadata_spec.loader is None: + raise RuntimeError(f"Could not load Alembic metadata helper from {metadata_path}") +metadata_module = importlib.util.module_from_spec(metadata_spec) +metadata_spec.loader.exec_module(metadata_module) +build_metadata_from_sql = metadata_module.build_metadata_from_sql + + +config = context.config + +database_url = os.environ.get("POLICYENGINE_ALEMBIC_DATABASE_URL") or os.environ.get( + "DATABASE_URL" +) +if database_url: + config.set_main_option("sqlalchemy.url", database_url) + +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +schema_sql_path = os.environ.get("POLICYENGINE_ALEMBIC_SCHEMA_SQL") +target_metadata = build_metadata_from_sql(schema_sql_path) + + +def _configure_context(connection=None, url: str | None = None) -> None: + options = { + "target_metadata": target_metadata, + "compare_type": False, + "compare_server_default": False, + } + if connection is not None: + context.configure(connection=connection, **options) + else: + context.configure( + url=url, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + **options, + ) + + +def run_migrations_offline() -> None: + _configure_context(url=config.get_main_option("sqlalchemy.url")) + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + _configure_context(connection=connection) + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 000000000..11016301e --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/20260511_60d38593ddc3_initial_legacy_schema.py b/alembic/versions/20260511_60d38593ddc3_initial_legacy_schema.py new file mode 100644 index 000000000..39bd19baf --- /dev/null +++ b/alembic/versions/20260511_60d38593ddc3_initial_legacy_schema.py @@ -0,0 +1,264 @@ +"""Initial legacy schema + +Revision ID: 60d38593ddc3 +Revises: +Create Date: 2026-05-11 20:19:44.056995 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "60d38593ddc3" +down_revision: Union[str, Sequence[str], None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "analysis", + sa.Column("prompt_id", sa.INTEGER(), nullable=False), + sa.Column("prompt", sa.TEXT(), nullable=False), + sa.Column("analysis", sa.TEXT(), nullable=True), + sa.Column("status", sa.VARCHAR(length=32), nullable=False), + sa.PrimaryKeyConstraint("prompt_id"), + ) + op.create_table( + "computed_household", + sa.Column("household_id", sa.INTEGER(), nullable=False), + sa.Column("policy_id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("api_version", sa.VARCHAR(length=10), nullable=False), + sa.Column("computed_household_json", sa.JSON(), nullable=False), + sa.Column("status", sa.VARCHAR(length=32), nullable=True), + sa.PrimaryKeyConstraint("household_id", "policy_id", "country_id"), + ) + op.create_table( + "economy", + sa.Column("economy_id", sa.INTEGER(), nullable=False), + sa.Column("policy_id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("region", sa.VARCHAR(length=32), nullable=True), + sa.Column("time_period", sa.VARCHAR(length=32), nullable=True), + sa.Column("options_json", sa.JSON(), nullable=False), + sa.Column("options_hash", sa.VARCHAR(length=255), nullable=False), + sa.Column("api_version", sa.VARCHAR(length=10), nullable=False), + sa.Column("economy_json", sa.JSON(), nullable=True), + sa.Column("status", sa.VARCHAR(length=32), nullable=False), + sa.Column("message", sa.VARCHAR(length=255), nullable=True), + sa.PrimaryKeyConstraint("economy_id"), + ) + op.create_table( + "household", + sa.Column("id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("label", sa.VARCHAR(length=255), nullable=True), + sa.Column("api_version", sa.VARCHAR(length=255), nullable=False), + sa.Column("household_json", sa.JSON(), nullable=False), + sa.Column("household_hash", sa.VARCHAR(length=255), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "legacy_report_output_aliases", + sa.Column("legacy_report_output_id", sa.INTEGER(), nullable=False), + sa.Column("canonical_report_output_id", sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint("legacy_report_output_id"), + ) + op.create_table( + "policy", + sa.Column("id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("label", sa.VARCHAR(length=255), nullable=True), + sa.Column("api_version", sa.VARCHAR(length=10), nullable=False), + sa.Column("policy_json", sa.JSON(), nullable=False), + sa.Column("policy_hash", sa.VARCHAR(length=255), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "reform_impact", + sa.Column("reform_impact_id", sa.INTEGER(), nullable=False), + sa.Column("baseline_policy_id", sa.INTEGER(), nullable=False), + sa.Column("reform_policy_id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("region", sa.VARCHAR(length=32), nullable=False), + sa.Column("dataset", sa.VARCHAR(length=255), nullable=False), + sa.Column("time_period", sa.VARCHAR(length=32), nullable=False), + sa.Column("options_json", sa.JSON(), nullable=False), + sa.Column("options_hash", sa.VARCHAR(length=255), nullable=False), + sa.Column("api_version", sa.VARCHAR(length=10), nullable=False), + sa.Column("reform_impact_json", sa.JSON(), nullable=False), + sa.Column("status", sa.VARCHAR(length=32), nullable=False), + sa.Column("message", sa.VARCHAR(length=255), nullable=True), + sa.Column("start_time", sa.DATETIME(), nullable=False), + sa.Column("end_time", sa.DATETIME(), nullable=True), + sa.Column("execution_id", sa.VARCHAR(length=255), nullable=False), + sa.PrimaryKeyConstraint("reform_impact_id"), + ) + op.create_table( + "report_output_runs", + sa.Column("id", sa.CHAR(length=36), nullable=False), + sa.Column("report_output_id", sa.INTEGER(), nullable=False), + sa.Column("run_sequence", sa.INTEGER(), nullable=False), + sa.Column("status", sa.VARCHAR(length=32), nullable=False), + sa.Column("output", sa.JSON(), nullable=True), + sa.Column("error_message", sa.TEXT(), nullable=True), + sa.Column("trigger_type", sa.VARCHAR(length=32), nullable=False), + sa.Column("requested_at", sa.DATETIME(), nullable=True), + sa.Column("started_at", sa.DATETIME(), nullable=True), + sa.Column("finished_at", sa.DATETIME(), nullable=True), + sa.Column("source_run_id", sa.CHAR(length=36), nullable=True), + sa.Column("report_spec_snapshot_json", sa.JSON(), nullable=True), + sa.Column("country_package_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("policyengine_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("data_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("runtime_app_name", sa.VARCHAR(length=255), nullable=True), + sa.Column("report_cache_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("simulation_cache_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("requested_version_override", sa.VARCHAR(length=255), nullable=True), + sa.Column("resolved_dataset", sa.VARCHAR(length=255), nullable=True), + sa.Column("resolved_options_hash", sa.VARCHAR(length=255), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("report_output_id", "run_sequence"), + ) + op.create_table( + "report_outputs", + sa.Column("id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("simulation_1_id", sa.INTEGER(), nullable=False), + sa.Column("simulation_2_id", sa.INTEGER(), nullable=True), + sa.Column("api_version", sa.VARCHAR(length=10), nullable=False), + sa.Column( + "status", + sa.VARCHAR(length=32), + server_default=sa.text("'pending'"), + nullable=False, + ), + sa.Column("output", sa.JSON(), nullable=True), + sa.Column("error_message", sa.TEXT(), nullable=True), + sa.Column( + "year", + sa.VARCHAR(length=255), + server_default=sa.text("'2025'"), + nullable=True, + ), + sa.Column("report_kind", sa.VARCHAR(length=64), nullable=True), + sa.Column("report_spec_json", sa.JSON(), nullable=True), + sa.Column("report_spec_schema_version", sa.INTEGER(), nullable=True), + sa.Column("report_spec_status", sa.VARCHAR(length=32), nullable=True), + sa.Column("active_run_id", sa.CHAR(length=36), nullable=True), + sa.Column("latest_successful_run_id", sa.CHAR(length=36), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "simulation_runs", + sa.Column("id", sa.CHAR(length=36), nullable=False), + sa.Column("simulation_id", sa.INTEGER(), nullable=False), + sa.Column("report_output_run_id", sa.CHAR(length=36), nullable=True), + sa.Column("input_position", sa.INTEGER(), nullable=True), + sa.Column("run_sequence", sa.INTEGER(), nullable=False), + sa.Column("status", sa.VARCHAR(length=32), nullable=False), + sa.Column("output", sa.JSON(), nullable=True), + sa.Column("error_message", sa.TEXT(), nullable=True), + sa.Column("trigger_type", sa.VARCHAR(length=32), nullable=False), + sa.Column("requested_at", sa.DATETIME(), nullable=True), + sa.Column("started_at", sa.DATETIME(), nullable=True), + sa.Column("finished_at", sa.DATETIME(), nullable=True), + sa.Column("source_run_id", sa.CHAR(length=36), nullable=True), + sa.Column("simulation_spec_snapshot_json", sa.JSON(), nullable=True), + sa.Column("country_package_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("policyengine_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("data_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("runtime_app_name", sa.VARCHAR(length=255), nullable=True), + sa.Column("simulation_cache_version", sa.VARCHAR(length=255), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("simulation_id", "run_sequence"), + ) + op.create_table( + "simulations", + sa.Column("id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("api_version", sa.VARCHAR(length=10), nullable=False), + sa.Column("population_id", sa.VARCHAR(length=255), nullable=False), + sa.Column("population_type", sa.VARCHAR(length=50), nullable=False), + sa.Column("policy_id", sa.INTEGER(), nullable=False), + sa.Column( + "status", + sa.VARCHAR(length=32), + server_default=sa.text("'pending'"), + nullable=False, + ), + sa.Column("output", sa.JSON(), nullable=True), + sa.Column("error_message", sa.TEXT(), nullable=True), + sa.Column("simulation_spec_json", sa.JSON(), nullable=True), + sa.Column("simulation_spec_schema_version", sa.INTEGER(), nullable=True), + sa.Column("active_run_id", sa.CHAR(length=36), nullable=True), + sa.Column("latest_successful_run_id", sa.CHAR(length=36), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "tracers", + sa.Column("id", sa.INTEGER(), nullable=False), + sa.Column("household_id", sa.INTEGER(), nullable=False), + sa.Column("policy_id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("api_version", sa.VARCHAR(length=10), nullable=False), + sa.Column("tracer_output", sa.JSON(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "user_policies", + sa.Column("id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("reform_id", sa.INTEGER(), nullable=False), + sa.Column("reform_label", sa.VARCHAR(length=255), nullable=True), + sa.Column("baseline_id", sa.INTEGER(), nullable=False), + sa.Column("baseline_label", sa.VARCHAR(length=255), nullable=True), + sa.Column("user_id", sa.VARCHAR(length=255), nullable=False), + sa.Column("year", sa.VARCHAR(length=32), nullable=False), + sa.Column("geography", sa.VARCHAR(length=255), nullable=False), + sa.Column("dataset", sa.VARCHAR(length=255), nullable=True), + sa.Column("number_of_provisions", sa.INTEGER(), nullable=False), + sa.Column("api_version", sa.VARCHAR(length=32), nullable=False), + sa.Column("added_date", sa.BIGINT(), nullable=False), + sa.Column("updated_date", sa.BIGINT(), nullable=False), + sa.Column("budgetary_impact", sa.VARCHAR(length=255), nullable=True), + sa.Column("type", sa.VARCHAR(length=255), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "user_profiles", + sa.Column("user_id", sa.INTEGER(), nullable=False), + sa.Column("auth0_id", sa.VARCHAR(length=255), nullable=False), + sa.Column("username", sa.VARCHAR(length=255), nullable=True), + sa.Column("primary_country", sa.VARCHAR(length=3), nullable=False), + sa.Column("user_since", sa.BIGINT(), nullable=False), + sa.PrimaryKeyConstraint("user_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("user_profiles") + op.drop_table("user_policies") + op.drop_table("tracers") + op.drop_table("simulations") + op.drop_table("simulation_runs") + op.drop_table("report_outputs") + op.drop_table("report_output_runs") + op.drop_table("reform_impact") + op.drop_table("policy") + op.drop_table("legacy_report_output_aliases") + op.drop_table("household") + op.drop_table("economy") + op.drop_table("computed_household") + op.drop_table("analysis") + # ### end Alembic commands ### diff --git a/alembic/versions/20260511_d39d9c54ce35_add_report_run_canonical_schema.py b/alembic/versions/20260511_d39d9c54ce35_add_report_run_canonical_schema.py new file mode 100644 index 000000000..ac73db5be --- /dev/null +++ b/alembic/versions/20260511_d39d9c54ce35_add_report_run_canonical_schema.py @@ -0,0 +1,72 @@ +"""Add report run canonical schema + +Revision ID: d39d9c54ce35 +Revises: 60d38593ddc3 +Create Date: 2026-05-11 20:20:06.697209 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "d39d9c54ce35" +down_revision: Union[str, Sequence[str], None] = "60d38593ddc3" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "legacy_report_output_id_map", + sa.Column("legacy_report_output_id", sa.INTEGER(), nullable=False), + sa.Column("canonical_report_output_id", sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint("legacy_report_output_id"), + ) + op.create_index( + "legacy_report_output_id_map_canonical_idx", + "legacy_report_output_id_map", + ["canonical_report_output_id"], + unique=False, + ) + op.add_column( + "report_outputs", + sa.Column("report_identity_hash", sa.VARCHAR(length=64), nullable=True), + ) + op.add_column( + "report_outputs", + sa.Column("report_identity_schema_version", sa.INTEGER(), nullable=True), + ) + op.create_index( + "report_outputs_identity_idx", + "report_outputs", + ["country_id", "report_identity_hash", "report_identity_schema_version"], + unique=False, + ) + op.create_index( + "simulation_runs_report_output_run_idx", + "simulation_runs", + ["report_output_run_id"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("simulation_runs_report_output_run_idx", table_name="simulation_runs") + op.drop_index("report_outputs_identity_idx", table_name="report_outputs") + op.drop_column("report_outputs", "report_identity_schema_version") + op.drop_column("report_outputs", "report_identity_hash") + op.drop_index( + "legacy_report_output_id_map_canonical_idx", + table_name="legacy_report_output_id_map", + ) + op.drop_table("legacy_report_output_id_map") + # ### end Alembic commands ### diff --git a/policyengine_api/data/alembic_metadata.py b/policyengine_api/data/alembic_metadata.py new file mode 100644 index 000000000..580e79867 --- /dev/null +++ b/policyengine_api/data/alembic_metadata.py @@ -0,0 +1,56 @@ +"""Build Alembic target metadata from the existing SQL initializer.""" + +from pathlib import Path +import sqlite3 + +from sqlalchemy import JSON, MetaData, create_engine + + +DEFAULT_SCHEMA_SQL = Path(__file__).with_name("initialise_local.sql") + +JSON_COLUMN_NAMES = { + "computed_household_json", + "economy_json", + "household_json", + "options_json", + "policy_json", + "reform_impact_json", + "report_spec_json", + "report_spec_snapshot_json", + "simulation_spec_json", + "simulation_spec_snapshot_json", + "tracer_output", + "output", +} + + +def _normalize_reflected_metadata(metadata: MetaData) -> None: + for table in metadata.tables.values(): + for column in table.columns: + if column.name in JSON_COLUMN_NAMES: + column.type = JSON() + if column.primary_key: + column.nullable = False + if column.server_default is not None: + default_arg = str(column.server_default.arg).strip().upper() + if "NULL" in default_arg: + column.server_default = None + + +def build_metadata_from_sql(schema_sql_path: str | Path | None = None) -> MetaData: + """Reflect SQL initializer DDL into SQLAlchemy metadata for autogenerate. + + API v1 still uses raw SQL rather than ORM models. This keeps Alembic's + autogenerate path tied to the existing initializer instead of maintaining a + second manually-authored schema definition. + """ + + schema_sql_path = Path(schema_sql_path or DEFAULT_SCHEMA_SQL) + connection = sqlite3.connect(":memory:") + connection.executescript(schema_sql_path.read_text()) + + engine = create_engine("sqlite://", creator=lambda: connection) + metadata = MetaData() + metadata.reflect(bind=engine) + _normalize_reflected_metadata(metadata) + return metadata diff --git a/policyengine_api/data/initialise.sql b/policyengine_api/data/initialise.sql index 917ca0cf0..8c3cf2421 100644 --- a/policyengine_api/data/initialise.sql +++ b/policyengine_api/data/initialise.sql @@ -197,6 +197,11 @@ CREATE TABLE IF NOT EXISTS simulation_runs ( CREATE INDEX simulation_runs_report_output_run_idx ON simulation_runs (report_output_run_id); +CREATE TABLE IF NOT EXISTS legacy_report_output_aliases ( + legacy_report_output_id INT PRIMARY KEY, + canonical_report_output_id INT NOT NULL +); + CREATE TABLE IF NOT EXISTS legacy_report_output_id_map ( legacy_report_output_id INT PRIMARY KEY, canonical_report_output_id INT NOT NULL diff --git a/policyengine_api/data/initialise_local.sql b/policyengine_api/data/initialise_local.sql index d92930257..1093308fc 100644 --- a/policyengine_api/data/initialise_local.sql +++ b/policyengine_api/data/initialise_local.sql @@ -8,6 +8,7 @@ DROP TABLE IF EXISTS user_policies; DROP TABLE IF EXISTS tracers; DROP TABLE IF EXISTS report_output_runs; DROP TABLE IF EXISTS simulation_runs; +DROP TABLE IF EXISTS legacy_report_output_aliases; DROP TABLE IF EXISTS legacy_report_output_id_map; CREATE TABLE IF NOT EXISTS household ( @@ -209,6 +210,11 @@ CREATE TABLE IF NOT EXISTS simulation_runs ( CREATE INDEX simulation_runs_report_output_run_idx ON simulation_runs (report_output_run_id); +CREATE TABLE IF NOT EXISTS legacy_report_output_aliases ( + legacy_report_output_id INT PRIMARY KEY, + canonical_report_output_id INT NOT NULL +); + CREATE TABLE IF NOT EXISTS legacy_report_output_id_map ( legacy_report_output_id INT PRIMARY KEY, canonical_report_output_id INT NOT NULL diff --git a/pyproject.toml b/pyproject.toml index f5063c0c3..8004ad179 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ classifiers = [ "License :: OSI Approved :: GNU Affero General Public License v3", ] dependencies = [ + "alembic>=1.13.0", "anthropic", "assertpy", "click>=8,<9", diff --git a/tests/unit/data/test_run_schema.py b/tests/unit/data/test_run_schema.py index 5635220d3..e0cd11a9b 100644 --- a/tests/unit/data/test_run_schema.py +++ b/tests/unit/data/test_run_schema.py @@ -66,6 +66,12 @@ def test_stage_one_run_schema_is_initialized_in_local_test_db(test_db): id_map_columns = _column_names(test_db, "legacy_report_output_id_map") assert {"legacy_report_output_id", "canonical_report_output_id"} == id_map_columns + legacy_alias_columns = _column_names(test_db, "legacy_report_output_aliases") + assert { + "legacy_report_output_id", + "canonical_report_output_id", + } == legacy_alias_columns + def test_stage_one_schema_is_defined_in_both_sql_initializers(): sql_paths = [ @@ -76,6 +82,7 @@ def test_stage_one_schema_is_defined_in_both_sql_initializers(): required_snippets = [ "CREATE TABLE IF NOT EXISTS report_output_runs", "CREATE TABLE IF NOT EXISTS simulation_runs", + "CREATE TABLE IF NOT EXISTS legacy_report_output_aliases", "CREATE TABLE IF NOT EXISTS legacy_report_output_id_map", "report_spec_json", "report_spec_status", diff --git a/uv.lock b/uv.lock index e16464992..b99629c78 100644 --- a/uv.lock +++ b/uv.lock @@ -146,6 +146,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "alembic" +version = "1.18.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mako" }, + { name = "sqlalchemy" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/13/8b084e0f2efb0275a1d534838844926f798bd766566b1375174e2448cd31/alembic-1.18.4.tar.gz", hash = "sha256:cb6e1fd84b6174ab8dbb2329f86d631ba9559dd78df550b57804d607672cedbc", size = 2056725, upload-time = "2026-02-10T16:00:47.195Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/29/6533c317b74f707ea28f8d633734dbda2119bbadfc61b2f3640ba835d0f7/alembic-1.18.4-py3-none-any.whl", hash = "sha256:a5ed4adcf6d8a4cb575f3d759f071b03cd6e5c7618eb796cb52497be25bfe19a", size = 263893, upload-time = "2026-02-10T16:00:49.997Z" }, +] + [[package]] name = "altair" version = "6.1.0" @@ -1811,6 +1825,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/91/6c074015990f4f656f7b69a5c2d15924906ce0bc19c7014ac953493c0cf0/linecheck-0.1.0-py3-none-any.whl", hash = "sha256:73c6b29790521fa711b00df7cd60af4caf7004337d8710606881fbecb0d1bc83", size = 2767, upload-time = "2022-07-16T13:06:17.01Z" }, ] +[[package]] +name = "mako" +version = "1.3.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/62/791b31e69ae182791ec67f04850f2f062716bbd205483d63a215f3e062d3/mako-1.3.12.tar.gz", hash = "sha256:9f778e93289bd410bb35daadeb4fc66d95a746f0b75777b942088b7fd7af550a", size = 400219, upload-time = "2026-04-28T19:01:08.512Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/b1/a0ec7a5a9db730a08daef1fdfb8090435b82465abbf758a596f0ea88727e/mako-1.3.12-py3-none-any.whl", hash = "sha256:8f61569480282dbf557145ce441e4ba888be453c30989f879f0d652e39f53ea9", size = 78521, upload-time = "2026-04-28T19:01:10.393Z" }, +] + [[package]] name = "markdown-it-py" version = "4.0.0" @@ -2625,6 +2651,7 @@ name = "policyengine-api" version = "3.40.12" source = { editable = "." } dependencies = [ + { name = "alembic" }, { name = "anthropic" }, { name = "assertpy" }, { name = "click" }, @@ -2670,6 +2697,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "alembic", specifier = ">=1.13.0" }, { name = "anthropic" }, { name = "assertpy" }, { name = "build", marker = "extra == 'dev'" }, From 8c1f5a63453c152b26f377f9227a7cd12ec35b9b Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 22:26:39 +0200 Subject: [PATCH 14/17] Enforce canonical report identity --- ...decda5_add_report_run_canonical_schema.py} | 8 +- policyengine_api/data/initialise.sql | 2 +- policyengine_api/data/initialise_local.sql | 2 +- .../routes/report_output_routes.py | 14 +- .../services/report_output_service.py | 153 ++++++++--- tests/unit/data/test_run_schema.py | 7 + .../services/test_report_output_service.py | 248 +++++------------- tests/unit/test_stage5_routes.py | 168 +++++++++--- 8 files changed, 344 insertions(+), 258 deletions(-) rename alembic/versions/{20260511_d39d9c54ce35_add_report_run_canonical_schema.py => 20260511_558935decda5_add_report_run_canonical_schema.py} (94%) diff --git a/alembic/versions/20260511_d39d9c54ce35_add_report_run_canonical_schema.py b/alembic/versions/20260511_558935decda5_add_report_run_canonical_schema.py similarity index 94% rename from alembic/versions/20260511_d39d9c54ce35_add_report_run_canonical_schema.py rename to alembic/versions/20260511_558935decda5_add_report_run_canonical_schema.py index ac73db5be..d643b361e 100644 --- a/alembic/versions/20260511_d39d9c54ce35_add_report_run_canonical_schema.py +++ b/alembic/versions/20260511_558935decda5_add_report_run_canonical_schema.py @@ -1,8 +1,8 @@ """Add report run canonical schema -Revision ID: d39d9c54ce35 +Revision ID: 558935decda5 Revises: 60d38593ddc3 -Create Date: 2026-05-11 20:20:06.697209 +Create Date: 2026-05-11 22:21:20.417733 """ @@ -13,7 +13,7 @@ # revision identifiers, used by Alembic. -revision: str = "d39d9c54ce35" +revision: str = "558935decda5" down_revision: Union[str, Sequence[str], None] = "60d38593ddc3" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -46,7 +46,7 @@ def upgrade() -> None: "report_outputs_identity_idx", "report_outputs", ["country_id", "report_identity_hash", "report_identity_schema_version"], - unique=False, + unique=1, ) op.create_index( "simulation_runs_report_output_run_idx", diff --git a/policyengine_api/data/initialise.sql b/policyengine_api/data/initialise.sql index 8c3cf2421..379e3e54a 100644 --- a/policyengine_api/data/initialise.sql +++ b/policyengine_api/data/initialise.sql @@ -141,7 +141,7 @@ CREATE TABLE IF NOT EXISTS report_outputs ( latest_successful_run_id CHAR(36) DEFAULT NULL ); -CREATE INDEX report_outputs_identity_idx +CREATE UNIQUE INDEX report_outputs_identity_idx ON report_outputs ( country_id, report_identity_hash, report_identity_schema_version ); diff --git a/policyengine_api/data/initialise_local.sql b/policyengine_api/data/initialise_local.sql index 1093308fc..fe0cb22c8 100644 --- a/policyengine_api/data/initialise_local.sql +++ b/policyengine_api/data/initialise_local.sql @@ -154,7 +154,7 @@ CREATE TABLE IF NOT EXISTS report_outputs ( latest_successful_run_id CHAR(36) DEFAULT NULL ); -CREATE INDEX report_outputs_identity_idx +CREATE UNIQUE INDEX report_outputs_identity_idx ON report_outputs ( country_id, report_identity_hash, report_identity_schema_version ); diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index cfafca18d..1b85d72ef 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -105,6 +105,7 @@ def create_report_output(country_id: str) -> Response: country_id=country_id, explicit_report_spec=report_spec, report_spec_schema_version=report_spec_schema_version, + ensure_current_report_cache_run=True, ) ) # Report already exists, return it with 200 status @@ -309,7 +310,18 @@ def update_report_output(country_id: str) -> Response: if not success: raise BadRequest("No fields to update") - updated_report = report_output_service.get_report_output(country_id, report_id) + if report_output_run_id is not None: + updated_report = report_output_service.get_report_output_for_run( + country_id, + report_id, + report_output_run_id, + ) + else: + updated_report = report_output_service.get_report_output( + country_id, report_id + ) + if updated_report is None: + raise NotFound(f"Report #{report_id} not found.") response_body = dict( status="ok", diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index 1372cdb13..44e36b4d4 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -652,6 +652,50 @@ def _insert_bootstrap_report_run( version_manifest=version_manifest, ) + def _ensure_current_report_cache_run_in_transaction( + self, + tx, + report_output: dict, + report_spec: ReportSpec | None, + simulation_1: dict | None, + simulation_2: dict | None, + runs_descending: list[dict], + ) -> list[dict]: + current_report_cache_version = get_report_output_cache_version( + report_output["country_id"] + ) + if any( + run.get("report_cache_version") == current_report_cache_version + for run in runs_descending + ): + return runs_descending + + source_run = select_display_report_run(report_output, runs_descending) + version_manifest = self._build_bootstrap_version_manifest( + report_output, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + ) + version_manifest["report_cache_version"] = current_report_cache_version + + self.report_run_service.create_report_output_run_in_transaction( + tx, + report_output["id"], + status="pending", + trigger_type="rerun", + source_run_id=(source_run["id"] if source_run is not None else None), + report_spec_snapshot=( + report_spec.model_dump() if report_spec is not None else None + ), + version_manifest=version_manifest, + ) + report_output["status"] = "pending" + report_output["output"] = None + report_output["error_message"] = None + report_output["api_version"] = current_report_cache_version + return self._list_report_runs_descending(report_output["id"], queryer=tx) + def _update_report_run_in_transaction( self, tx, @@ -811,6 +855,7 @@ def _ensure_report_output_dual_write_state_in_transaction( explicit_report_spec: ReportSpec | None = None, report_spec_schema_version: int | None = None, version_manifest_overrides: dict[str, str | None] | None = None, + ensure_current_report_cache_run: bool = False, ) -> dict: report_output = self._get_report_output_row( report_output_id, @@ -911,6 +956,26 @@ def _ensure_report_output_dual_write_state_in_transaction( report_output_id, queryer=tx ) + if ensure_current_report_cache_run: + runs_descending = self._ensure_current_report_cache_run_in_transaction( + tx, + report_output, + report_spec, + simulation_1, + simulation_2, + runs_descending, + ) + refreshed_report_output = ( + self._sync_parent_mirror_from_display_run_in_transaction( + tx, + report_output, + runs_descending, + ) + ) + return self._with_display_run_timestamps( + refreshed_report_output, queryer=tx + ) + self._sync_parent_pointers_in_transaction(tx, report_output, runs_descending) refreshed_report_output = self._get_report_output_row( report_output_id, @@ -928,6 +993,7 @@ def ensure_report_output_dual_write_state( explicit_report_spec: ReportSpec | None = None, report_spec_schema_version: int | None = None, version_manifest_overrides: dict[str, str | None] | None = None, + ensure_current_report_cache_run: bool = False, ) -> dict: return database.transaction( lambda tx: self._ensure_report_output_dual_write_state_in_transaction( @@ -937,6 +1003,7 @@ def ensure_report_output_dual_write_state( explicit_report_spec=explicit_report_spec, report_spec_schema_version=report_spec_schema_version, version_manifest_overrides=version_manifest_overrides, + ensure_current_report_cache_run=ensure_current_report_cache_run, ) ) @@ -993,18 +1060,17 @@ def _find_existing_report_output_row( queryer=None, ) -> dict | None: queryer = queryer or database - api_version = get_report_output_cache_version(country_id) query = """ SELECT * FROM report_outputs - WHERE country_id = ? AND simulation_1_id = ? AND year = ? AND api_version = ? + WHERE country_id = ? AND simulation_1_id = ? AND year = ? """ - params: list[int | str] = [country_id, simulation_1_id, year, api_version] + params: list[int | str] = [country_id, simulation_1_id, year] if simulation_2_id is not None: query += " AND simulation_2_id = ?" params.append(simulation_2_id) else: query += " AND simulation_2_id IS NULL" - query += " ORDER BY id DESC" + query += " ORDER BY id ASC" row = queryer.query(query, tuple(params)).fetchone() return dict(row) if row is not None else None @@ -1262,37 +1328,6 @@ def _merge_display_run_into_report_output( result[field] = self._format_run_timestamp(display_run.get(field)) return result - def find_existing_report_output( - self, - country_id: str, - simulation_1_id: int, - simulation_2_id: int | None = None, - year: str = "2025", - ) -> dict | None: - """ - Find an existing report output with the same simulation IDs and year. - """ - print("Checking for existing report output") - - try: - existing_report = self._find_existing_report_output_row( - country_id=country_id, - simulation_1_id=simulation_1_id, - simulation_2_id=simulation_2_id, - year=year, - ) - if existing_report is not None: - print(f"Found existing report output with ID: {existing_report['id']}") - return self.ensure_report_output_dual_write_state( - existing_report["id"], - country_id=country_id, - ) - return None - - except Exception as e: - print(f"Error checking for existing report output. Details: {str(e)}") - raise e - def find_existing_report_output_for_create( self, country_id: str, @@ -1358,6 +1393,7 @@ def tx_callback(tx): country_id=country_id, explicit_report_spec=report_spec, report_spec_schema_version=report_spec_schema_version, + ensure_current_report_cache_run=True, ) self._require_simulation_exists( @@ -1489,6 +1525,53 @@ def get_report_output(self, country_id: str, report_output_id: int) -> dict | No ) raise e + def get_report_output_for_run( + self, + country_id: str, + report_output_id: int, + report_output_run_id: str, + ) -> dict | None: + """ + Get a report output projected through one explicit run. + + Normal report reads intentionally apply display-run selection. PATCH + responses for an explicit run need the narrower projection so workers + see the run they just updated, even if it is not the report's display + run. + """ + resolution = self.report_output_id_map_service.resolve_report_output_id( + report_output_id, + country_id=country_id, + ) + if resolution is None: + return None + + canonical_report_output_id = resolution["canonical_report_output_id"] + canonical_report_output = self._get_report_output_row( + canonical_report_output_id, + country_id=country_id, + ) + if canonical_report_output is None: + return None + + explicit_run = self._get_report_run_row( + report_output_run_id, + report_output_id=canonical_report_output_id, + ) + if explicit_run is None: + return None + + resolved_report_output = self._merge_display_run_into_report_output( + canonical_report_output, + explicit_run, + ) + if resolution["is_legacy_id"]: + return self._with_requested_report_output_id( + report_output_id, + resolved_report_output, + ) + return resolved_report_output + def create_report_rerun( self, country_id: str, diff --git a/tests/unit/data/test_run_schema.py b/tests/unit/data/test_run_schema.py index e0cd11a9b..531b0af17 100644 --- a/tests/unit/data/test_run_schema.py +++ b/tests/unit/data/test_run_schema.py @@ -8,6 +8,11 @@ def _column_names(test_db, table_name: str) -> set[str]: return {row["name"] for row in rows} +def _index_is_unique(test_db, table_name: str, index_name: str) -> bool: + rows = test_db.query(f"PRAGMA index_list({table_name})").fetchall() + return any(row["name"] == index_name and row["unique"] == 1 for row in rows) + + def test_stage_one_run_schema_is_initialized_in_local_test_db(test_db): report_output_columns = _column_names(test_db, "report_outputs") assert { @@ -71,6 +76,7 @@ def test_stage_one_run_schema_is_initialized_in_local_test_db(test_db): "legacy_report_output_id", "canonical_report_output_id", } == legacy_alias_columns + assert _index_is_unique(test_db, "report_outputs", "report_outputs_identity_idx") def test_stage_one_schema_is_defined_in_both_sql_initializers(): @@ -84,6 +90,7 @@ def test_stage_one_schema_is_defined_in_both_sql_initializers(): "CREATE TABLE IF NOT EXISTS simulation_runs", "CREATE TABLE IF NOT EXISTS legacy_report_output_aliases", "CREATE TABLE IF NOT EXISTS legacy_report_output_id_map", + "CREATE UNIQUE INDEX report_outputs_identity_idx", "report_spec_json", "report_spec_status", "report_identity_hash", diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index 312a295b8..c26469a78 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -11,8 +11,6 @@ from policyengine_api.services.run_sync_utils import select_display_report_run from policyengine_api.services.simulation_service import SimulationService -from tests.fixtures.services import report_output_fixtures - pytest_plugins = ("tests.fixtures.services.report_output_fixtures",) service = ReportOutputService() @@ -80,130 +78,6 @@ def test_select_display_run_uses_matching_result_before_newest_fallback(self): assert selected_run["id"] == "matching" -class TestFindExistingReportOutput: - """Test finding existing report outputs in the database.""" - - def test_find_existing_report_output_found(self, test_db, existing_report_record): - """Test finding an existing report output.""" - # GIVEN an existing report record (from fixture) - - # WHEN we search for a report with matching simulation IDs - result = service.find_existing_report_output( - country_id=existing_report_record["country_id"], - simulation_1_id=existing_report_record["simulation_1_id"], - simulation_2_id=existing_report_record["simulation_2_id"], - ) - - # THEN the result should contain the existing report - assert result is not None - assert result["id"] == existing_report_record["id"] - assert ( - result["country_id"] - == report_output_fixtures.valid_report_data["country_id"] - ) - assert result["simulation_1_id"] == existing_report_record["simulation_1_id"] - assert result["status"] == existing_report_record["status"] - - def test_find_existing_report_output_not_found(self, test_db): - """Test that None is returned when no report exists.""" - # GIVEN an empty database - - # WHEN we search for a non-existent report - result = service.find_existing_report_output( - country_id="us", - simulation_1_id=999, - simulation_2_id=888, - year="2025", - ) - - # THEN None should be returned - assert result is None - - def test_find_existing_report_output_with_null_simulation2(self, test_db): - """Test finding reports where simulation_2_id is NULL.""" - api_version = get_report_output_cache_version("us") - # GIVEN a report with NULL simulation_2_id - test_db.query( - "INSERT INTO report_outputs (country_id, simulation_1_id, simulation_2_id, status, api_version, year) VALUES (?, ?, ?, ?, ?, ?)", - ("us", 100, None, "complete", api_version, "2025"), - ) - - # WHEN we search for it - result = service.find_existing_report_output( - country_id="us", - simulation_1_id=100, - simulation_2_id=None, - year="2025", - ) - - # THEN we should find it - assert result is not None - assert result["simulation_1_id"] == 100 - assert result["simulation_2_id"] is None - assert result["year"] == "2025" - - def test_find_existing_report_output_with_year(self, test_db): - """Test finding reports with different years.""" - api_version = get_report_output_cache_version("us") - # GIVEN reports with different years for the same simulation - test_db.query( - "INSERT INTO report_outputs (country_id, simulation_1_id, simulation_2_id, status, api_version, year) VALUES (?, ?, ?, ?, ?, ?)", - ("us", 101, None, "complete", api_version, "2025"), - ) - test_db.query( - "INSERT INTO report_outputs (country_id, simulation_1_id, simulation_2_id, status, api_version, year) VALUES (?, ?, ?, ?, ?, ?)", - ("us", 101, None, "complete", api_version, "2024"), - ) - - # WHEN we search for the 2025 report - result_2025 = service.find_existing_report_output( - country_id="us", - simulation_1_id=101, - simulation_2_id=None, - year="2025", - ) - - # THEN we should find the 2025 report - assert result_2025 is not None - assert result_2025["simulation_1_id"] == 101 - assert result_2025["year"] == "2025" - - # WHEN we search for the 2024 report - result_2024 = service.find_existing_report_output( - country_id="us", - simulation_1_id=101, - simulation_2_id=None, - year="2024", - ) - - # THEN we should find the 2024 report - assert result_2024 is not None - assert result_2024["simulation_1_id"] == 101 - assert result_2024["year"] == "2024" - - # AND the two reports should have different IDs - assert result_2025["id"] != result_2024["id"] - - def test_find_existing_report_output_ignores_stale_runtime_version(self, test_db): - current_version = get_report_output_cache_version("us") - stale_version = "r0stale1" - assert stale_version != current_version - - test_db.query( - "INSERT INTO report_outputs (country_id, simulation_1_id, simulation_2_id, status, api_version, year) VALUES (?, ?, ?, ?, ?, ?)", - ("us", 102, None, "complete", stale_version, "2025"), - ) - - result = service.find_existing_report_output( - country_id="us", - simulation_1_id=102, - simulation_2_id=None, - year="2025", - ) - - assert result is None - - class TestCreateReportOutput: """Test creating new report outputs in the database.""" @@ -676,6 +550,76 @@ def fail_legacy_key_lookup(**_kwargs): assert created_report["simulation_2_id"] == reform_simulation["id"] assert created_report["report_identity_hash"] is not None + def test_create_report_output_reuses_stale_report_and_adds_current_run( + self, test_db + ): + stale_version = "r0stale1" + current_version = get_report_output_cache_version("us") + assert stale_version != current_version + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_create_stale_runtime", + population_type="household", + policy_id=41, + ) + test_db.query( + """ + INSERT INTO report_outputs + (country_id, simulation_1_id, simulation_2_id, status, output, api_version, year) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + "us", + simulation["id"], + None, + "complete", + json.dumps({"result": "stale"}), + stale_version, + "2026", + ), + ) + stale_report = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + result = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + + assert result["id"] == stale_report["id"] + assert result["status"] == "pending" + assert result["output"] is None + assert result["api_version"] == current_version + + report_rows = test_db.query( + """ + SELECT * FROM report_outputs + WHERE country_id = ? AND simulation_1_id = ? AND year = ? + """, + ("us", simulation["id"], "2026"), + ).fetchall() + assert len(report_rows) == 1 + + runs = test_db.query( + """ + SELECT * FROM report_output_runs + WHERE report_output_id = ? + ORDER BY run_sequence ASC + """, + (stale_report["id"],), + ).fetchall() + assert len(runs) == 2 + assert runs[0]["status"] == "complete" + assert runs[0]["output"] == json.dumps({"result": "stale"}) + assert runs[0]["report_cache_version"] == stale_version + assert runs[1]["status"] == "pending" + assert runs[1]["output"] is None + assert runs[1]["report_cache_version"] == current_version + assert runs[1]["source_run_id"] == runs[0]["id"] + def test_find_existing_for_create_validates_explicit_spec_context_before_reuse( self, test_db ): @@ -1387,38 +1331,6 @@ def test_get_report_output_bootstraps_running_legacy_run_started_at(self, test_d assert run["started_at"] is not None assert run["finished_at"] is None - def test_find_existing_report_output_backfills_missing_timestamps(self, test_db): - simulation = simulation_service.create_simulation( - country_id="us", - population_id="household_report_legacy_timestamp_find", - population_type="household", - policy_id=50, - ) - report = service.create_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2025", - ) - test_db.query( - """ - UPDATE report_output_runs - SET requested_at = NULL - WHERE report_output_id = ? - """, - (report["id"],), - ) - - result = service.find_existing_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2025", - ) - - assert result is not None - assert result["requested_at"] is not None - def test_get_report_output_uses_selected_display_run_for_canonical_parent( self, test_db ): @@ -1485,26 +1397,6 @@ def test_get_report_output_resolves_legacy_id_to_canonical_display_run( status="complete", output=json.dumps({"budget": {"budgetary_impact": 3}}), ) - test_db.query( - """ - INSERT INTO report_outputs ( - id, country_id, simulation_1_id, simulation_2_id, status, output, api_version, year, - report_identity_hash, report_identity_schema_version - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - 999, - "us", - simulation["id"], - None, - "error", - json.dumps({"legacy": True}), - "r0legacy1", - "2025", - canonical_report["report_identity_hash"], - canonical_report["report_identity_schema_version"], - ), - ) id_map_service.set_mapping( legacy_report_output_id=999, canonical_report_output_id=canonical_report["id"], diff --git a/tests/unit/test_stage5_routes.py b/tests/unit/test_stage5_routes.py index 99636a242..a1b58ed05 100644 --- a/tests/unit/test_stage5_routes.py +++ b/tests/unit/test_stage5_routes.py @@ -128,6 +128,76 @@ def test_create_report_output_existing_row_repairs_dual_write_state(test_db): assert snapshot["report_kind"] == "household_single" +def test_create_report_output_existing_stale_row_adds_current_run(test_db): + stale_version = "r0stale1" + current_version = get_report_output_cache_version("us") + assert stale_version != current_version + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_stale_report_create", + population_type="household", + policy_id=42, + ) + test_db.query( + """ + INSERT INTO report_outputs ( + country_id, simulation_1_id, simulation_2_id, api_version, status, output, year + ) VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + "us", + simulation["id"], + None, + stale_version, + "complete", + json.dumps({"result": "stale"}), + "2026", + ), + ) + stale_report = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + client = create_test_client() + response = client.post( + "/us/report", + json={ + "simulation_1_id": simulation["id"], + "year": "2026", + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["id"] == stale_report["id"] + assert payload["result"]["status"] == "pending" + assert payload["result"]["output"] is None + assert payload["result"]["api_version"] == current_version + + report_rows = test_db.query( + """ + SELECT * FROM report_outputs + WHERE country_id = ? AND simulation_1_id = ? AND year = ? + """, + ("us", simulation["id"], "2026"), + ).fetchall() + assert len(report_rows) == 1 + + runs = test_db.query( + """ + SELECT * FROM report_output_runs + WHERE report_output_id = ? + ORDER BY run_sequence ASC + """, + (stale_report["id"],), + ).fetchall() + assert len(runs) == 2 + assert runs[0]["report_cache_version"] == stale_version + assert runs[0]["output"] == json.dumps({"result": "stale"}) + assert runs[1]["report_cache_version"] == current_version + assert runs[1]["status"] == "pending" + + def test_post_report_output_returns_timestamp_fields_for_new_and_existing_report( test_db, ): @@ -1114,6 +1184,66 @@ def test_patch_report_output_explicit_run_id_updates_only_that_run(test_db): assert rerun_after["output"] == json.dumps({"result": "explicit rerun"}) +def test_patch_report_output_explicit_run_id_response_uses_that_run(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_explicit_report_run_response", + population_type="household", + policy_id=78, + ) + report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + report_output_service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"result": "initial"}), + ) + older_run = report_run_service.create_report_output_run( + report["id"], + status="complete", + trigger_type="rerun", + output=json.dumps({"result": "older before patch"}), + ) + newer_run = report_run_service.create_report_output_run( + report["id"], + status="complete", + trigger_type="rerun", + output=json.dumps({"result": "newer display"}), + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": report["id"], + "report_output_run_id": older_run["id"], + "status": "complete", + "output": json.dumps({"result": "older patched"}), + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["output"] == json.dumps({"result": "older patched"}) + assert payload["result"]["finished_at"] is not None + + get_response = client.get(f"/us/report/{report['id']}") + assert get_response.status_code == 200 + get_payload = get_response.get_json() + assert get_payload["result"]["output"] == json.dumps({"result": "newer display"}) + + stored_report = test_db.query( + "SELECT latest_successful_run_id FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + assert stored_report["latest_successful_run_id"] == newer_run["id"] + + def test_patch_report_output_explicit_run_id_through_legacy_id_updates_canonical_run( test_db, ): @@ -1142,25 +1272,6 @@ def test_patch_report_output_explicit_run_id_through_legacy_id_updates_canonical rerun = report_run_service.create_report_output_run( canonical_report["id"], trigger_type="rerun" ) - test_db.query( - """ - INSERT INTO report_outputs ( - id, country_id, simulation_1_id, simulation_2_id, api_version, status, year, - report_identity_hash, report_identity_schema_version - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - 3002, - "us", - simulation["id"], - None, - "legacy-report-cache-version", - "complete", - "2026", - canonical_report["report_identity_hash"], - canonical_report["report_identity_schema_version"], - ), - ) report_output_id_map_service.set_mapping( legacy_report_output_id=3002, canonical_report_output_id=canonical_report["id"], @@ -1264,25 +1375,6 @@ def test_create_report_rerun_via_legacy_id_creates_canonical_linked_runs(test_db status="complete", output=json.dumps({"result": "canonical"}), ) - test_db.query( - """ - INSERT INTO report_outputs ( - id, country_id, simulation_1_id, simulation_2_id, api_version, status, year, - report_identity_hash, report_identity_schema_version - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - 3001, - "us", - simulation["id"], - None, - "legacy-report-cache-version", - "complete", - "2026", - canonical_report["report_identity_hash"], - canonical_report["report_identity_schema_version"], - ), - ) report_output_id_map_service.set_mapping( legacy_report_output_id=3001, canonical_report_output_id=canonical_report["id"], From 94228b3392afd546d77e52dd9c16fc5b83562827 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 12 May 2026 14:45:52 +0200 Subject: [PATCH 15/17] Clarify report identity contract --- ...5decda5_add_report_run_canonical_schema.py | 2 +- policyengine_api/data/initialise.sql | 2 +- policyengine_api/data/initialise_local.sql | 2 +- policyengine_api/routes/simulation_routes.py | 17 +- .../services/report_output_id_map_service.py | 6 +- .../services/report_spec_service.py | 48 +++++- .../services/simulation_service.py | 27 ++++ tests/unit/data/test_run_schema.py | 14 +- .../unit/services/test_report_spec_service.py | 148 ++++++++++++++---- tests/unit/test_stage5_routes.py | 79 ++++++++++ 10 files changed, 295 insertions(+), 50 deletions(-) diff --git a/alembic/versions/20260511_558935decda5_add_report_run_canonical_schema.py b/alembic/versions/20260511_558935decda5_add_report_run_canonical_schema.py index d643b361e..4e1e744df 100644 --- a/alembic/versions/20260511_558935decda5_add_report_run_canonical_schema.py +++ b/alembic/versions/20260511_558935decda5_add_report_run_canonical_schema.py @@ -46,7 +46,7 @@ def upgrade() -> None: "report_outputs_identity_idx", "report_outputs", ["country_id", "report_identity_hash", "report_identity_schema_version"], - unique=1, + unique=False, ) op.create_index( "simulation_runs_report_output_run_idx", diff --git a/policyengine_api/data/initialise.sql b/policyengine_api/data/initialise.sql index 379e3e54a..8c3cf2421 100644 --- a/policyengine_api/data/initialise.sql +++ b/policyengine_api/data/initialise.sql @@ -141,7 +141,7 @@ CREATE TABLE IF NOT EXISTS report_outputs ( latest_successful_run_id CHAR(36) DEFAULT NULL ); -CREATE UNIQUE INDEX report_outputs_identity_idx +CREATE INDEX report_outputs_identity_idx ON report_outputs ( country_id, report_identity_hash, report_identity_schema_version ); diff --git a/policyengine_api/data/initialise_local.sql b/policyengine_api/data/initialise_local.sql index fe0cb22c8..1093308fc 100644 --- a/policyengine_api/data/initialise_local.sql +++ b/policyengine_api/data/initialise_local.sql @@ -154,7 +154,7 @@ CREATE TABLE IF NOT EXISTS report_outputs ( latest_successful_run_id CHAR(36) DEFAULT NULL ); -CREATE UNIQUE INDEX report_outputs_identity_idx +CREATE INDEX report_outputs_identity_idx ON report_outputs ( country_id, report_identity_hash, report_identity_schema_version ); diff --git a/policyengine_api/routes/simulation_routes.py b/policyengine_api/routes/simulation_routes.py index 19cafe473..375e1d194 100644 --- a/policyengine_api/routes/simulation_routes.py +++ b/policyengine_api/routes/simulation_routes.py @@ -230,10 +230,19 @@ def update_simulation(country_id: str) -> Response: if not success: raise BadRequest("No fields to update") - # Get the updated record - updated_simulation = simulation_service.get_simulation( - country_id, simulation_id - ) + if simulation_run_id is not None: + updated_simulation = simulation_service.get_simulation_for_run( + country_id, + simulation_id, + simulation_run_id, + ) + else: + updated_simulation = simulation_service.get_simulation( + country_id, + simulation_id, + ) + if updated_simulation is None: + raise NotFound(f"Simulation #{simulation_id} not found.") response_body = dict( status="ok", diff --git a/policyengine_api/services/report_output_id_map_service.py b/policyengine_api/services/report_output_id_map_service.py index 80b341f3a..507f0570a 100644 --- a/policyengine_api/services/report_output_id_map_service.py +++ b/policyengine_api/services/report_output_id_map_service.py @@ -95,13 +95,17 @@ def resolve_report_output_id( canonical_report_output = self._get_report_output_row( canonical_report_output_id, queryer=queryer, - country_id=country_id, ) if canonical_report_output is None: raise ValueError( "Legacy ID mapping points to missing canonical report output " f"#{canonical_report_output_id}" ) + if ( + country_id is not None + and canonical_report_output["country_id"] != country_id + ): + return None return { "requested_report_output_id": requested_report_output_id, "canonical_report_output_id": canonical_report_output_id, diff --git a/policyengine_api/services/report_spec_service.py b/policyengine_api/services/report_spec_service.py index 457d0dc0e..3a3134232 100644 --- a/policyengine_api/services/report_spec_service.py +++ b/policyengine_api/services/report_spec_service.py @@ -379,23 +379,55 @@ def canonicalize_report_spec_for_identity( self, report_spec: ReportSpec, schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, + ) -> dict[str, Any]: + return self.build_report_identity_document( + report_spec, + schema_version=schema_version, + ) + + def build_report_identity_document( + self, + report_spec: ReportSpec, + schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, ) -> dict[str, Any]: self._validate_report_identity_schema_version(schema_version) - canonical_spec = report_spec.model_dump() - if ( - isinstance(report_spec, EconomyReportSpec) - and report_spec.country_id == "us" - ): - canonical_spec["region"] = normalize_us_region(canonical_spec["region"]) - return canonical_spec + identity_document: dict[str, Any] = { + "schema_version": schema_version, + "country_id": report_spec.country_id, + "report_kind": report_spec.report_kind, + "time_period": report_spec.time_period, + } + if isinstance(report_spec, HouseholdReportSpec): + identity_document["inputs"] = { + "simulation_1": report_spec.simulation_1.model_dump(), + "simulation_2": ( + report_spec.simulation_2.model_dump() + if report_spec.simulation_2 is not None + else None + ), + } + return identity_document + + region = report_spec.region + if report_spec.country_id == "us": + region = normalize_us_region(region) + identity_document["inputs"] = { + "region": region, + "baseline_policy_id": report_spec.baseline_policy_id, + "reform_policy_id": report_spec.reform_policy_id, + "dataset": report_spec.dataset, + "target": report_spec.target, + "options": report_spec.options, + } + return identity_document def serialize_canonical_report_spec_for_identity( self, report_spec: ReportSpec, schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, ) -> str: - canonical_spec = self.canonicalize_report_spec_for_identity( + canonical_spec = self.build_report_identity_document( report_spec, schema_version=schema_version, ) diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index 3fb6cca2c..865a6e5e9 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -697,6 +697,33 @@ def get_simulation(self, country_id: str, simulation_id: int) -> dict | None: print(f"Error fetching simulation #{simulation_id}. Details: {str(e)}") raise e + def get_simulation_for_run( + self, + country_id: str, + simulation_id: int, + simulation_run_id: str, + ) -> dict | None: + """ + Get a simulation projected through one explicit run. + + Normal simulation reads intentionally apply display-run selection. PATCH + responses for an explicit run need the narrower projection so workers + see the run they just updated, even if it is not the simulation's + display run. + """ + simulation = self._get_simulation_row(simulation_id, country_id=country_id) + if simulation is None: + return None + + explicit_run = self._get_simulation_run_row( + simulation_run_id, + simulation_id=simulation_id, + ) + if explicit_run is None: + return None + + return self._merge_display_run_into_simulation(simulation, explicit_run) + def update_simulation( self, country_id: str, diff --git a/tests/unit/data/test_run_schema.py b/tests/unit/data/test_run_schema.py index 531b0af17..e6c7d3db6 100644 --- a/tests/unit/data/test_run_schema.py +++ b/tests/unit/data/test_run_schema.py @@ -13,6 +13,11 @@ def _index_is_unique(test_db, table_name: str, index_name: str) -> bool: return any(row["name"] == index_name and row["unique"] == 1 for row in rows) +def _index_exists(test_db, table_name: str, index_name: str) -> bool: + rows = test_db.query(f"PRAGMA index_list({table_name})").fetchall() + return any(row["name"] == index_name for row in rows) + + def test_stage_one_run_schema_is_initialized_in_local_test_db(test_db): report_output_columns = _column_names(test_db, "report_outputs") assert { @@ -76,7 +81,12 @@ def test_stage_one_run_schema_is_initialized_in_local_test_db(test_db): "legacy_report_output_id", "canonical_report_output_id", } == legacy_alias_columns - assert _index_is_unique(test_db, "report_outputs", "report_outputs_identity_idx") + assert _index_exists(test_db, "report_outputs", "report_outputs_identity_idx") + assert not _index_is_unique( + test_db, + "report_outputs", + "report_outputs_identity_idx", + ) def test_stage_one_schema_is_defined_in_both_sql_initializers(): @@ -90,7 +100,7 @@ def test_stage_one_schema_is_defined_in_both_sql_initializers(): "CREATE TABLE IF NOT EXISTS simulation_runs", "CREATE TABLE IF NOT EXISTS legacy_report_output_aliases", "CREATE TABLE IF NOT EXISTS legacy_report_output_id_map", - "CREATE UNIQUE INDEX report_outputs_identity_idx", + "CREATE INDEX report_outputs_identity_idx", "report_spec_json", "report_spec_status", "report_identity_hash", diff --git a/tests/unit/services/test_report_spec_service.py b/tests/unit/services/test_report_spec_service.py index 27d0a22fa..4cc0259e2 100644 --- a/tests/unit/services/test_report_spec_service.py +++ b/tests/unit/services/test_report_spec_service.py @@ -471,28 +471,102 @@ def test_rejects_unsupported_schema_version_on_read(self, test_db): class TestReportIdentity: - def test_canonical_identity_reuses_normalized_us_region(self): + def test_builds_household_identity_document(self): + report_spec = HouseholdReportSpec.model_validate( + { + "country_id": "uk", + "report_kind": "household_comparison", + "time_period": "2027", + "simulation_1": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 1, + }, + "simulation_2": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 2, + }, + } + ) + + identity_document = report_spec_service.build_report_identity_document( + report_spec + ) + + assert identity_document == { + "schema_version": REPORT_IDENTITY_SCHEMA_VERSION, + "country_id": "uk", + "report_kind": "household_comparison", + "time_period": "2027", + "inputs": { + "simulation_1": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 1, + }, + "simulation_2": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 2, + }, + }, + } + + def test_builds_economy_identity_document_with_normalized_region(self): report_spec = EconomyReportSpec.model_validate( { "country_id": "us", - "report_kind": "economy_single", + "report_kind": "economy_comparison", "time_period": "2027", "region": "ca", "baseline_policy_id": 10, - "reform_policy_id": 10, - "dataset": "default", - "target": "general", - "options": {}, + "reform_policy_id": 11, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, } ) - canonical_spec = report_spec_service.canonicalize_report_spec_for_identity( + identity_document = report_spec_service.build_report_identity_document( report_spec ) - assert canonical_spec["region"] == "state/ca" + assert identity_document == { + "schema_version": REPORT_IDENTITY_SCHEMA_VERSION, + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "inputs": { + "region": "state/ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + } + + def test_canonicalize_helper_returns_identity_document(self): + report_spec = HouseholdReportSpec.model_validate( + { + "country_id": "uk", + "report_kind": "household_single", + "time_period": "2027", + "simulation_1": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 1, + }, + "simulation_2": None, + } + ) - def test_equal_specs_produce_equal_hashes_despite_json_key_order(self): + assert report_spec_service.canonicalize_report_spec_for_identity( + report_spec + ) == report_spec_service.build_report_identity_document(report_spec) + + def test_equal_specs_produce_equal_hashes_despite_nested_options_key_order(self): first_spec = EconomyReportSpec.model_validate( { "country_id": "us", @@ -503,7 +577,7 @@ def test_equal_specs_produce_equal_hashes_despite_json_key_order(self): "reform_policy_id": 11, "dataset": "default", "target": "general", - "options": {"b": 2, "a": 1}, + "options": {"outer": {"b": 2, "a": 1}, "enabled": True}, } ) second_spec = EconomyReportSpec.model_validate( @@ -516,7 +590,7 @@ def test_equal_specs_produce_equal_hashes_despite_json_key_order(self): "reform_policy_id": 11, "dataset": "default", "target": "general", - "options": {"a": 1, "b": 2}, + "options": {"enabled": True, "outer": {"a": 1, "b": 2}}, } ) @@ -524,31 +598,41 @@ def test_equal_specs_produce_equal_hashes_despite_json_key_order(self): first_spec ) == report_spec_service.get_report_identity_hash(second_spec) - def test_distinct_economy_dataset_changes_identity_hash(self): + @pytest.mark.parametrize( + ("field_name", "replacement_value"), + [ + ("time_period", "2028"), + ("region", "state/ny"), + ("baseline_policy_id", 12), + ("reform_policy_id", 13), + ("dataset", "enhanced_us_household"), + ("target", "cliff"), + ("options", {"view": "tax"}), + ], + ) + def test_distinct_economy_definition_fields_change_identity_hash( + self, + field_name, + replacement_value, + ): + base_spec_data = { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "region": "state/ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "default", + "target": "general", + "options": {}, + } first_spec = EconomyReportSpec.model_validate( - { - "country_id": "us", - "report_kind": "economy_comparison", - "time_period": "2027", - "region": "state/ca", - "baseline_policy_id": 10, - "reform_policy_id": 11, - "dataset": "default", - "target": "general", - "options": {}, - } + base_spec_data ) second_spec = EconomyReportSpec.model_validate( { - "country_id": "us", - "report_kind": "economy_comparison", - "time_period": "2027", - "region": "state/ca", - "baseline_policy_id": 10, - "reform_policy_id": 11, - "dataset": "enhanced_us_household", - "target": "general", - "options": {}, + **base_spec_data, + field_name: replacement_value, } ) diff --git a/tests/unit/test_stage5_routes.py b/tests/unit/test_stage5_routes.py index a1b58ed05..42d9613c2 100644 --- a/tests/unit/test_stage5_routes.py +++ b/tests/unit/test_stage5_routes.py @@ -687,6 +687,61 @@ def test_patch_simulation_explicit_run_id_updates_only_that_run(test_db): assert rerun_after["output"] == json.dumps({"result": "explicit rerun"}) +def test_patch_simulation_explicit_run_id_response_uses_that_run(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_explicit_simulation_run_response", + population_type="household", + policy_id=79, + ) + simulation_service.update_simulation( + country_id="us", + simulation_id=simulation["id"], + status="complete", + output=json.dumps({"result": "initial"}), + ) + older_run = simulation_run_service.create_simulation_run( + simulation["id"], + status="complete", + trigger_type="rerun", + output=json.dumps({"result": "older before patch"}), + ) + newer_run = simulation_run_service.create_simulation_run( + simulation["id"], + status="complete", + trigger_type="rerun", + output=json.dumps({"result": "newer display"}), + ) + test_db.query( + """ + UPDATE simulations + SET latest_successful_run_id = ? + WHERE id = ? + """, + (newer_run["id"], simulation["id"]), + ) + + client = create_test_client() + response = client.patch( + "/us/simulation", + json={ + "id": simulation["id"], + "simulation_run_id": older_run["id"], + "status": "complete", + "output": json.dumps({"result": "older patched"}), + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["output"] == json.dumps({"result": "older patched"}) + + get_response = client.get(f"/us/simulation/{simulation['id']}") + assert get_response.status_code == 200 + get_payload = get_response.get_json() + assert get_payload["result"]["output"] == json.dumps({"result": "newer display"}) + + def test_patch_simulation_rejects_non_string_run_metadata(test_db): simulation = simulation_service.create_simulation( country_id="us", @@ -726,6 +781,30 @@ def test_get_report_output_wrong_country_returns_not_found(test_db): assert response.status_code == 404 +def test_get_report_output_legacy_id_wrong_country_returns_not_found(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_alias_wrong_country", + population_type="household", + policy_id=56, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + report_output_id_map_service.set_mapping( + legacy_report_output_id=2000, + canonical_report_output_id=canonical_report["id"], + ) + + client = create_test_client() + response = client.get("/uk/report/2000") + + assert response.status_code == 404 + + def test_get_report_output_legacy_id_resolves_to_canonical_display_run(test_db): simulation = simulation_service.create_simulation( country_id="us", From 7cd5d3f0cfda945de09877bd5f27e9b03036e7e6 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 12 May 2026 21:29:03 +0200 Subject: [PATCH 16/17] Pin legacy report IDs to migrated runs --- ...5decda5_add_report_run_canonical_schema.py | 1 + policyengine_api/data/initialise.sql | 3 +- policyengine_api/data/initialise_local.sql | 3 +- .../services/report_output_id_map_service.py | 63 ++++++++- .../services/report_output_service.py | 52 ++++++-- tests/unit/data/test_run_schema.py | 7 +- .../test_report_output_id_map_service.py | 71 ++++++++++ .../services/test_report_output_service.py | 23 +++- tests/unit/test_stage5_routes.py | 123 ++++++++++++++---- 9 files changed, 297 insertions(+), 49 deletions(-) diff --git a/alembic/versions/20260511_558935decda5_add_report_run_canonical_schema.py b/alembic/versions/20260511_558935decda5_add_report_run_canonical_schema.py index 4e1e744df..065c8a53a 100644 --- a/alembic/versions/20260511_558935decda5_add_report_run_canonical_schema.py +++ b/alembic/versions/20260511_558935decda5_add_report_run_canonical_schema.py @@ -26,6 +26,7 @@ def upgrade() -> None: "legacy_report_output_id_map", sa.Column("legacy_report_output_id", sa.INTEGER(), nullable=False), sa.Column("canonical_report_output_id", sa.INTEGER(), nullable=False), + sa.Column("display_report_output_run_id", sa.CHAR(length=36), nullable=False), sa.PrimaryKeyConstraint("legacy_report_output_id"), ) op.create_index( diff --git a/policyengine_api/data/initialise.sql b/policyengine_api/data/initialise.sql index 8c3cf2421..51ee2710a 100644 --- a/policyengine_api/data/initialise.sql +++ b/policyengine_api/data/initialise.sql @@ -204,7 +204,8 @@ CREATE TABLE IF NOT EXISTS legacy_report_output_aliases ( CREATE TABLE IF NOT EXISTS legacy_report_output_id_map ( legacy_report_output_id INT PRIMARY KEY, - canonical_report_output_id INT NOT NULL + canonical_report_output_id INT NOT NULL, + display_report_output_run_id CHAR(36) NOT NULL ); CREATE INDEX legacy_report_output_id_map_canonical_idx diff --git a/policyengine_api/data/initialise_local.sql b/policyengine_api/data/initialise_local.sql index 1093308fc..6aae6006b 100644 --- a/policyengine_api/data/initialise_local.sql +++ b/policyengine_api/data/initialise_local.sql @@ -217,7 +217,8 @@ CREATE TABLE IF NOT EXISTS legacy_report_output_aliases ( CREATE TABLE IF NOT EXISTS legacy_report_output_id_map ( legacy_report_output_id INT PRIMARY KEY, - canonical_report_output_id INT NOT NULL + canonical_report_output_id INT NOT NULL, + display_report_output_run_id CHAR(36) NOT NULL ); CREATE INDEX legacy_report_output_id_map_canonical_idx diff --git a/policyengine_api/services/report_output_id_map_service.py b/policyengine_api/services/report_output_id_map_service.py index 507f0570a..1fc27c8fa 100644 --- a/policyengine_api/services/report_output_id_map_service.py +++ b/policyengine_api/services/report_output_id_map_service.py @@ -26,6 +26,24 @@ def _get_report_output_row( row: Row | None = queryer.query(query, tuple(params)).fetchone() return dict(row) if row is not None else None + def _get_report_output_run_row( + self, + report_output_run_id: str, + *, + canonical_report_output_id: int, + queryer=None, + ) -> dict | None: + queryer = queryer or database + row: Row | None = queryer.query( + """ + SELECT id, report_output_id + FROM report_output_runs + WHERE id = ? AND report_output_id = ? + """, + (report_output_run_id, canonical_report_output_id), + ).fetchone() + return dict(row) if row is not None else None + def _validate_mapping_identity_compatibility( self, legacy_report_output: dict, @@ -106,9 +124,22 @@ def resolve_report_output_id( and canonical_report_output["country_id"] != country_id ): return None + display_report_output_run_id = mapping["display_report_output_run_id"] + display_run = self._get_report_output_run_row( + display_report_output_run_id, + canonical_report_output_id=canonical_report_output_id, + queryer=queryer, + ) + if display_run is None: + raise ValueError( + "Legacy ID mapping points to missing display report output run " + f"#{display_report_output_run_id} for canonical report output " + f"#{canonical_report_output_id}" + ) return { "requested_report_output_id": requested_report_output_id, "canonical_report_output_id": canonical_report_output_id, + "display_report_output_run_id": display_report_output_run_id, "is_legacy_id": True, } @@ -123,6 +154,7 @@ def resolve_report_output_id( return { "requested_report_output_id": requested_report_output_id, "canonical_report_output_id": requested_report_output_id, + "display_report_output_run_id": None, "is_legacy_id": False, } @@ -146,6 +178,7 @@ def set_mapping( self, legacy_report_output_id: int, canonical_report_output_id: int, + display_report_output_run_id: str, ) -> bool: if legacy_report_output_id == canonical_report_output_id: raise ValueError("Legacy and canonical report outputs must be different") @@ -163,12 +196,15 @@ def set_mapping( if ( existing_mapping["canonical_report_output_id"] == canonical_report_output_id + and existing_mapping["display_report_output_run_id"] + == display_report_output_run_id ): return True raise ValueError( "Legacy report output ID already maps to canonical report output " - f"#{existing_mapping['canonical_report_output_id']}" + f"#{existing_mapping['canonical_report_output_id']} and display " + f"run #{existing_mapping['display_report_output_run_id']}" ) legacy_report_output = self._get_report_output_row(legacy_report_output_id) @@ -178,12 +214,31 @@ def set_mapping( canonical_report_output, ) + display_run = self._get_report_output_run_row( + display_report_output_run_id, + canonical_report_output_id=canonical_report_output_id, + ) + if display_run is None: + raise ValueError( + "Display report output run " + f"#{display_report_output_run_id} not found for canonical report " + f"#{canonical_report_output_id}" + ) + database.query( """ INSERT INTO legacy_report_output_id_map - (legacy_report_output_id, canonical_report_output_id) - VALUES (?, ?) + ( + legacy_report_output_id, + canonical_report_output_id, + display_report_output_run_id + ) + VALUES (?, ?, ?) """, - (legacy_report_output_id, canonical_report_output_id), + ( + legacy_report_output_id, + canonical_report_output_id, + display_report_output_run_id, + ), ) return True diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index 44e36b4d4..b507279d8 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -1491,23 +1491,36 @@ def get_report_output(self, country_id: str, report_output_id: int) -> dict | No if canonical_report_output is None: return None - display_run = self.report_run_service.select_display_run( - canonical_report_output - ) - if display_run is None or ( - run_matches_report_result(display_run, canonical_report_output) - and self._run_needs_timestamp_sync( - display_run, - canonical_report_output["status"], - ) - ): - canonical_report_output = self.ensure_report_output_dual_write_state( - canonical_report_output_id, - country_id=country_id, + if resolution["is_legacy_id"]: + display_run = self._get_report_run_row( + resolution["display_report_output_run_id"], + report_output_id=canonical_report_output_id, ) + if display_run is None: + raise ValueError( + "Legacy ID mapping points to missing display report output " + f"run #{resolution['display_report_output_run_id']}" + ) + else: display_run = self.report_run_service.select_display_run( canonical_report_output ) + if display_run is None or ( + run_matches_report_result(display_run, canonical_report_output) + and self._run_needs_timestamp_sync( + display_run, + canonical_report_output["status"], + ) + ): + canonical_report_output = ( + self.ensure_report_output_dual_write_state( + canonical_report_output_id, + country_id=country_id, + ) + ) + display_run = self.report_run_service.select_display_run( + canonical_report_output + ) resolved_report_output = self._merge_display_run_into_report_output( canonical_report_output, display_run, @@ -1818,6 +1831,19 @@ def tx_callback(tx): f"#{report_output_run_id} not found for report " f"#{canonical_report_id}" ) + elif resolution["is_legacy_id"]: + mutable_run = self._get_report_run_row( + resolution["display_report_output_run_id"], + queryer=tx, + report_output_id=canonical_report_id, + for_update=True, + ) + if mutable_run is None: + raise ValueError( + "Legacy ID mapping points to missing display report " + "output run " + f"#{resolution['display_report_output_run_id']}" + ) else: mutable_run = self._select_mutable_run( canonical_report, diff --git a/tests/unit/data/test_run_schema.py b/tests/unit/data/test_run_schema.py index e6c7d3db6..6da72f5de 100644 --- a/tests/unit/data/test_run_schema.py +++ b/tests/unit/data/test_run_schema.py @@ -74,7 +74,11 @@ def test_stage_one_run_schema_is_initialized_in_local_test_db(test_db): }.issubset(simulation_run_columns) id_map_columns = _column_names(test_db, "legacy_report_output_id_map") - assert {"legacy_report_output_id", "canonical_report_output_id"} == id_map_columns + assert { + "legacy_report_output_id", + "canonical_report_output_id", + "display_report_output_run_id", + } == id_map_columns legacy_alias_columns = _column_names(test_db, "legacy_report_output_aliases") assert { @@ -100,6 +104,7 @@ def test_stage_one_schema_is_defined_in_both_sql_initializers(): "CREATE TABLE IF NOT EXISTS simulation_runs", "CREATE TABLE IF NOT EXISTS legacy_report_output_aliases", "CREATE TABLE IF NOT EXISTS legacy_report_output_id_map", + "display_report_output_run_id", "CREATE INDEX report_outputs_identity_idx", "report_spec_json", "report_spec_status", diff --git a/tests/unit/services/test_report_output_id_map_service.py b/tests/unit/services/test_report_output_id_map_service.py index ce2a26326..f1dbcc412 100644 --- a/tests/unit/services/test_report_output_id_map_service.py +++ b/tests/unit/services/test_report_output_id_map_service.py @@ -43,6 +43,19 @@ def _insert_legacy_report_output( ), ) + def _display_run_id(self, test_db, report_output_id: int) -> str: + row = test_db.query( + """ + SELECT id FROM report_output_runs + WHERE report_output_id = ? + ORDER BY run_sequence DESC + LIMIT 1 + """, + (report_output_id,), + ).fetchone() + assert row is not None + return row["id"] + def test_resolves_to_canonical_report_output_id_when_mapping_exists(self, test_db): simulation = simulation_service.create_simulation( country_id="us", @@ -61,6 +74,9 @@ def test_resolves_to_canonical_report_output_id_when_mapping_exists(self, test_d id_map_service.set_mapping( legacy_report_output_id=999, canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), ) resolved_id = id_map_service.resolve_canonical_report_output_id(999) @@ -109,6 +125,9 @@ def test_set_mapping_is_idempotent_for_same_canonical_report_output(self, test_d id_map_service.set_mapping( legacy_report_output_id=1001, canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), ) is True ) @@ -116,6 +135,9 @@ def test_set_mapping_is_idempotent_for_same_canonical_report_output(self, test_d id_map_service.set_mapping( legacy_report_output_id=1001, canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), ) is True ) @@ -139,10 +161,35 @@ def test_rejects_mapping_to_missing_canonical_report_output(self, test_db): id_map_service.set_mapping( legacy_report_output_id=1002, canonical_report_output_id=999999, + display_report_output_run_id="missing-run", ) assert "Canonical report output #999999 not found" in str(exc_info.value) + def test_rejects_mapping_to_missing_display_report_output_run(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_3b", + population_type="household", + policy_id=3, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + self._insert_legacy_report_output(test_db, 10020, canonical_report) + + with pytest.raises(ValueError) as exc_info: + id_map_service.set_mapping( + legacy_report_output_id=10020, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id="missing-run", + ) + + assert "Display report output run #missing-run not found" in str(exc_info.value) + def test_rejects_conflicting_mapping_remap(self, test_db): simulation = simulation_service.create_simulation( country_id="us", @@ -166,12 +213,18 @@ def test_rejects_conflicting_mapping_remap(self, test_db): id_map_service.set_mapping( legacy_report_output_id=1003, canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), ) with pytest.raises(ValueError) as exc_info: id_map_service.set_mapping( legacy_report_output_id=1003, canonical_report_output_id=other_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, other_report["id"] + ), ) assert ( @@ -197,6 +250,9 @@ def test_allows_mapping_when_legacy_report_output_is_missing(self, test_db): id_map_service.set_mapping( legacy_report_output_id=10030, canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), ) is True ) @@ -205,6 +261,9 @@ def test_allows_mapping_when_legacy_report_output_is_missing(self, test_db): assert resolved == { "requested_report_output_id": 10030, "canonical_report_output_id": canonical_report["id"], + "display_report_output_run_id": self._display_run_id( + test_db, canonical_report["id"] + ), "is_legacy_id": True, } @@ -270,6 +329,9 @@ def test_rejects_mapping_when_reports_do_not_share_canonical_identity( id_map_service.set_mapping( legacy_report_output_id=distinct_report["id"], canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), ) assert "must share canonical report identity" in str(exc_info.value) @@ -307,6 +369,9 @@ def test_rejects_mapping_when_legacy_report_output_has_no_identity(self, test_db id_map_service.set_mapping( legacy_report_output_id=10031, canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), ) assert "must have canonical report identity" in str(exc_info.value) @@ -329,6 +394,9 @@ def test_rejects_mapping_when_legacy_and_canonical_ids_match(self, test_db): id_map_service.set_mapping( legacy_report_output_id=canonical_report["id"], canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), ) assert "must be different" in str(exc_info.value) @@ -352,6 +420,9 @@ def test_rejects_mapping_resolution_when_canonical_report_output_is_missing( id_map_service.set_mapping( legacy_report_output_id=1004, canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), ) test_db.query( "DELETE FROM report_outputs WHERE id = ?", diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index c26469a78..3de2197e4 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -1397,18 +1397,37 @@ def test_get_report_output_resolves_legacy_id_to_canonical_display_run( status="complete", output=json.dumps({"budget": {"budgetary_impact": 3}}), ) + canonical_run = test_db.query( + """ + SELECT * FROM report_output_runs + WHERE report_output_id = ? + ORDER BY run_sequence DESC + LIMIT 1 + """, + (canonical_report["id"],), + ).fetchone() + legacy_run = report_run_service.create_report_output_run( + canonical_report["id"], + status="error", + trigger_type="backfill", + output=json.dumps({"budget": {"budgetary_impact": -1}}), + error_message="legacy error", + ) id_map_service.set_mapping( legacy_report_output_id=999, canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=legacy_run["id"], ) result = service.get_report_output(country_id="us", report_output_id=999) assert result is not None assert result["id"] == 999 - assert result["status"] == "complete" - assert result["output"] == json.dumps({"budget": {"budgetary_impact": 3}}) + assert result["status"] == "error" + assert result["output"] == json.dumps({"budget": {"budgetary_impact": -1}}) + assert result["error_message"] == "legacy error" assert result["api_version"] == get_report_output_cache_version("us") + assert canonical_run["id"] != legacy_run["id"] def test_get_report_output_does_not_create_current_runtime_row_for_stale_id( self, test_db diff --git a/tests/unit/test_stage5_routes.py b/tests/unit/test_stage5_routes.py index 42d9613c2..054f2b8fd 100644 --- a/tests/unit/test_stage5_routes.py +++ b/tests/unit/test_stage5_routes.py @@ -29,6 +29,20 @@ def create_test_client() -> Flask: return app.test_client() +def get_display_report_run_id(test_db, report_output_id: int) -> str: + row = test_db.query( + """ + SELECT id FROM report_output_runs + WHERE report_output_id = ? + ORDER BY run_sequence DESC + LIMIT 1 + """, + (report_output_id,), + ).fetchone() + assert row is not None + return row["id"] + + def test_create_simulation_existing_row_repairs_dual_write_state(test_db): test_db.query( """INSERT INTO simulations @@ -797,6 +811,9 @@ def test_get_report_output_legacy_id_wrong_country_returns_not_found(test_db): report_output_id_map_service.set_mapping( legacy_report_output_id=2000, canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=get_display_report_run_id( + test_db, canonical_report["id"] + ), ) client = create_test_client() @@ -805,7 +822,7 @@ def test_get_report_output_legacy_id_wrong_country_returns_not_found(test_db): assert response.status_code == 404 -def test_get_report_output_legacy_id_resolves_to_canonical_display_run(test_db): +def test_get_report_output_legacy_id_resolves_to_pinned_display_run(test_db): simulation = simulation_service.create_simulation( country_id="us", population_id="household_route_alias", @@ -824,30 +841,17 @@ def test_get_report_output_legacy_id_resolves_to_canonical_display_run(test_db): status="complete", output=json.dumps({"result": "canonical"}), ) - test_db.query( - """ - INSERT INTO report_outputs ( - id, country_id, simulation_1_id, simulation_2_id, api_version, status, output, year - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - 2001, - "us", - simulation["id"], - None, - "r0legacy1", - "error", - json.dumps({"result": "legacy"}), - "2025", - ), + legacy_run = report_run_service.create_report_output_run( + canonical_report["id"], + status="error", + trigger_type="backfill", + output=json.dumps({"result": "legacy"}), + error_message="legacy failure", ) - test_db.query( - """ - INSERT INTO legacy_report_output_id_map ( - legacy_report_output_id, canonical_report_output_id - ) VALUES (?, ?) - """, - (2001, canonical_report["id"]), + report_output_id_map_service.set_mapping( + legacy_report_output_id=2001, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=legacy_run["id"], ) client = create_test_client() @@ -856,9 +860,9 @@ def test_get_report_output_legacy_id_resolves_to_canonical_display_run(test_db): assert response.status_code == 200 payload = response.get_json() assert payload["result"]["id"] == 2001 - assert payload["result"]["status"] == "complete" - assert payload["result"]["output"] == json.dumps({"result": "canonical"}) - assert payload["result"]["api_version"] == get_report_output_cache_version("us") + assert payload["result"]["status"] == "error" + assert payload["result"]["output"] == json.dumps({"result": "legacy"}) + assert payload["result"]["error_message"] == "legacy failure" def test_get_report_output_reads_malformed_legacy_row_without_runs_or_identity( @@ -1354,6 +1358,7 @@ def test_patch_report_output_explicit_run_id_through_legacy_id_updates_canonical report_output_id_map_service.set_mapping( legacy_report_output_id=3002, canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=initial_run["id"], ) client = create_test_client() @@ -1384,6 +1389,67 @@ def test_patch_report_output_explicit_run_id_through_legacy_id_updates_canonical assert rerun_after["output"] == json.dumps({"result": "legacy explicit rerun"}) +def test_patch_report_output_legacy_id_defaults_to_pinned_display_run(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_legacy_pinned_patch", + population_type="household", + policy_id=81, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + report_output_service.update_report_output( + country_id="us", + report_id=canonical_report["id"], + status="complete", + output=json.dumps({"result": "canonical initial"}), + ) + canonical_run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (canonical_report["id"],), + ).fetchone() + legacy_run = report_run_service.create_report_output_run( + canonical_report["id"], + status="pending", + trigger_type="backfill", + ) + report_output_id_map_service.set_mapping( + legacy_report_output_id=3003, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=legacy_run["id"], + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": 3003, + "status": "complete", + "output": json.dumps({"result": "legacy patched"}), + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["id"] == 3003 + assert payload["result"]["output"] == json.dumps({"result": "legacy patched"}) + + canonical_run_after = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (canonical_run["id"],), + ).fetchone() + legacy_run_after = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (legacy_run["id"],), + ).fetchone() + assert canonical_run_after["output"] == json.dumps({"result": "canonical initial"}) + assert legacy_run_after["output"] == json.dumps({"result": "legacy patched"}) + + def test_create_report_rerun_via_canonical_id_creates_canonical_linked_runs(test_db): simulation = simulation_service.create_simulation( country_id="us", @@ -1457,6 +1523,9 @@ def test_create_report_rerun_via_legacy_id_creates_canonical_linked_runs(test_db report_output_id_map_service.set_mapping( legacy_report_output_id=3001, canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=get_display_report_run_id( + test_db, canonical_report["id"] + ), ) client = create_test_client() From a2389d2114b3c2a4437b5077c278b8d2bbbc8a50 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 13 May 2026 15:06:36 +0200 Subject: [PATCH 17/17] Align report rerun response metadata --- .../routes/report_output_routes.py | 24 ++++++++++--- .../services/report_output_service.py | 2 ++ .../services/simulation_service.py | 7 +++- tests/unit/test_stage5_routes.py | 35 ++++++++++++++++++- 4 files changed, 62 insertions(+), 6 deletions(-) diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index 1b85d72ef..0354bdf1d 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -11,6 +11,15 @@ report_output_bp = Blueprint("report_output", __name__) report_output_service = ReportOutputService() +REPORT_OUTPUT_RESPONSE_INTERNAL_FIELDS = { + "active_run_id", + "latest_successful_run_id", + "report_identity_hash", + "report_identity_schema_version", + "report_spec_json", + "report_spec_schema_version", + "report_spec_status", +} RUN_METADATA_FIELDS = ( "country_package_version", "policyengine_version", @@ -32,6 +41,13 @@ def _parse_report_run_metadata(payload: dict) -> dict[str, str | None]: return metadata +def _serialize_report_output_response(report_output: dict) -> dict: + response_report = dict(report_output) + for field_name in REPORT_OUTPUT_RESPONSE_INTERNAL_FIELDS: + response_report.pop(field_name, None) + return response_report + + @report_output_bp.route("//report", methods=["POST"]) @validate_country def create_report_output(country_id: str) -> Response: @@ -112,7 +128,7 @@ def create_report_output(country_id: str) -> Response: response_body = dict( status="ok", message="Report output already exists", - result=existing_report, + result=_serialize_report_output_response(existing_report), ) return Response( @@ -134,7 +150,7 @@ def create_report_output(country_id: str) -> Response: response_body = dict( status="ok", message="Report output created successfully", - result=created_report, + result=_serialize_report_output_response(created_report), ) return Response( @@ -186,7 +202,7 @@ def get_report_output(country_id: str, report_id: int) -> Response: response_body = dict( status="ok", message=None, - result=report_output, + result=_serialize_report_output_response(report_output), ) return Response( @@ -326,7 +342,7 @@ def update_report_output(country_id: str) -> Response: response_body = dict( status="ok", message="Report output updated successfully", - result=updated_report, + result=_serialize_report_output_response(updated_report), ) return Response( diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index b507279d8..d8368dcf0 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -1322,6 +1322,7 @@ def _merge_display_run_into_report_output( result["status"] = display_run["status"] result["output"] = display_run.get("output") result["error_message"] = display_run.get("error_message") + result["display_report_output_run_id"] = display_run["id"] if display_run.get("report_cache_version") is not None: result["api_version"] = display_run["report_cache_version"] for field in ("requested_at", "started_at", "finished_at"): @@ -1692,6 +1693,7 @@ def tx_callback(tx): simulation, report_output_run_id=report_run_id, input_position=input_position, + version_manifest_overrides=version_manifest_overrides, ) simulation_run_ids.append(simulation_run["id"]) diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index 865a6e5e9..1bb93f08e 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -390,6 +390,7 @@ def create_report_rerun_simulation_run_in_transaction( *, report_output_run_id: str, input_position: int, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict: simulation_spec = self._upsert_simulation_spec_in_transaction(tx, simulation) runs_descending = self._list_simulation_runs_descending( @@ -401,9 +402,13 @@ def create_report_rerun_simulation_run_in_transaction( self._build_existing_run_version_manifest( source_run, simulation, + version_manifest_overrides=version_manifest_overrides, ) if source_run is not None - else self._build_bootstrap_version_manifest(simulation) + else self._build_bootstrap_version_manifest( + simulation, + version_manifest_overrides=version_manifest_overrides, + ) ) created_run = self.simulation_run_service.create_simulation_run_in_transaction( tx, diff --git a/tests/unit/test_stage5_routes.py b/tests/unit/test_stage5_routes.py index 054f2b8fd..50ddb845b 100644 --- a/tests/unit/test_stage5_routes.py +++ b/tests/unit/test_stage5_routes.py @@ -19,6 +19,15 @@ report_run_service = ReportRunService() report_output_id_map_service = ReportOutputIdMapService() simulation_run_service = SimulationRunService() +INTERNAL_REPORT_OUTPUT_RESPONSE_FIELDS = { + "active_run_id", + "latest_successful_run_id", + "report_identity_hash", + "report_identity_schema_version", + "report_spec_json", + "report_spec_schema_version", + "report_spec_status", +} def create_test_client() -> Flask: @@ -29,6 +38,10 @@ def create_test_client() -> Flask: return app.test_client() +def assert_report_output_response_hides_internal_fields(report_output: dict) -> None: + assert INTERNAL_REPORT_OUTPUT_RESPONSE_FIELDS.isdisjoint(report_output) + + def get_display_report_run_id(test_db, report_output_id: int) -> str: row = test_db.query( """ @@ -863,6 +876,8 @@ def test_get_report_output_legacy_id_resolves_to_pinned_display_run(test_db): assert payload["result"]["status"] == "error" assert payload["result"]["output"] == json.dumps({"result": "legacy"}) assert payload["result"]["error_message"] == "legacy failure" + assert payload["result"]["display_report_output_run_id"] == legacy_run["id"] + assert_report_output_response_hides_internal_fields(payload["result"]) def test_get_report_output_reads_malformed_legacy_row_without_runs_or_identity( @@ -1471,7 +1486,16 @@ def test_create_report_rerun_via_canonical_id_creates_canonical_linked_runs(test ) client = create_test_client() - response = client.post(f"/us/report/{canonical_report['id']}/rerun", json={}) + response = client.post( + f"/us/report/{canonical_report['id']}/rerun", + json={ + "country_package_version": "1.620.0", + "policyengine_version": "0.94.2", + "data_version": "2026.04.16", + "runtime_app_name": "policyengine-app-v2", + "resolved_dataset": "enhanced_us_household", + }, + ) assert response.status_code == 201 result = response.get_json()["result"] @@ -1492,6 +1516,11 @@ def test_create_report_rerun_via_canonical_id_creates_canonical_linked_runs(test assert report_runs[1]["id"] == result["report_output_run_id"] assert report_runs[1]["trigger_type"] == "rerun" assert report_runs[1]["status"] == "pending" + assert report_runs[1]["country_package_version"] == "1.620.0" + assert report_runs[1]["policyengine_version"] == "0.94.2" + assert report_runs[1]["data_version"] == "2026.04.16" + assert report_runs[1]["runtime_app_name"] == "policyengine-app-v2" + assert report_runs[1]["resolved_dataset"] == "enhanced_us_household" simulation_run = test_db.query( "SELECT * FROM simulation_runs WHERE id = ?", @@ -1499,6 +1528,10 @@ def test_create_report_rerun_via_canonical_id_creates_canonical_linked_runs(test ).fetchone() assert simulation_run["report_output_run_id"] == result["report_output_run_id"] assert simulation_run["input_position"] == 1 + assert simulation_run["country_package_version"] == "1.620.0" + assert simulation_run["policyengine_version"] == "0.94.2" + assert simulation_run["data_version"] == "2026.04.16" + assert simulation_run["runtime_app_name"] == "policyengine-app-v2" def test_create_report_rerun_via_legacy_id_creates_canonical_linked_runs(test_db):