diff --git a/.claude/skills/database-migrations.md b/.claude/skills/database-migrations.md new file mode 100644 index 000000000..d94ce61dc --- /dev/null +++ b/.claude/skills/database-migrations.md @@ -0,0 +1,35 @@ +# Database Migration Guidelines + +## Overview + +This project uses Alembic for database migrations. API v1 still uses raw SQL +initializers rather than ORM models, so Alembic target metadata is reflected +from `policyengine_api/data/initialise_local.sql` by default. + +## Rules + +- Do not manually author Alembic operations for normal schema changes. +- Generate migrations with `uv run alembic revision --autogenerate`. +- Review generated migrations before applying them. +- Keep SQL initializers and generated migrations aligned. +- For pre-existing production databases, stamp the base revision before applying + new upgrade revisions. + +## Commands + +```bash +uv run alembic revision --autogenerate -m "Description" +uv run alembic upgrade head +uv run alembic current +uv run alembic history +uv run alembic stamp 60d38593ddc3 +``` + +## API v1 Notes + +- Set `POLICYENGINE_ALEMBIC_DATABASE_URL` to the SQLAlchemy URL of the database + Alembic should connect to. +- Set `POLICYENGINE_ALEMBIC_SCHEMA_SQL` when generating against a temporary + schema SQL file instead of the current initializer. +- The base migration should be stamped in production because the tables already + exist there. diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 000000000..1404bb8e6 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,49 @@ +# Alembic configuration for PolicyEngine API v1. + +[alembic] +script_location = %(here)s/alembic +file_template = %%(year)d%%(month).2d%%(day).2d_%%(rev)s_%%(slug)s +prepend_sys_path = . +path_separator = os +output_encoding = utf-8 + +# Fallback URL: alembic/env.py replaces it when POLICYENGINE_ALEMBIC_DATABASE_URL +# or DATABASE_URL is set. For local generation, set +# POLICYENGINE_ALEMBIC_DATABASE_URL explicitly. 
+sqlalchemy.url = sqlite:///policyengine_api/data/policyengine.db + +[post_write_hooks] + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 000000000..6d8c5892c --- /dev/null +++ b/alembic/README @@ -0,0 +1,8 @@ +PolicyEngine API v1 Alembic migrations. + +This project does not currently use SQLAlchemy ORM models. Alembic +autogenerate reflects target metadata from `policyengine_api/data/initialise_local.sql` +or from the path in `POLICYENGINE_ALEMBIC_SCHEMA_SQL`. + +Use `alembic stamp` for pre-existing production databases before applying +incremental migrations. 
diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 000000000..620ffe8f4 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,83 @@ +"""Alembic environment for PolicyEngine API v1 raw-SQL schema migrations.""" + +from logging.config import fileConfig +import importlib.util +import os +from pathlib import Path +import sys + +from sqlalchemy import engine_from_config, pool + +from alembic import context + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +metadata_path = ( + Path(__file__).parent.parent / "policyengine_api" / "data" / "alembic_metadata.py" +) +metadata_spec = importlib.util.spec_from_file_location( + "policyengine_api_alembic_metadata", + metadata_path, +) +if metadata_spec is None or metadata_spec.loader is None: + raise RuntimeError(f"Could not load Alembic metadata helper from {metadata_path}") +metadata_module = importlib.util.module_from_spec(metadata_spec) +metadata_spec.loader.exec_module(metadata_module) +build_metadata_from_sql = metadata_module.build_metadata_from_sql + + +config = context.config + +database_url = os.environ.get("POLICYENGINE_ALEMBIC_DATABASE_URL") or os.environ.get( + "DATABASE_URL" +) +if database_url: + config.set_main_option("sqlalchemy.url", database_url) + +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +schema_sql_path = os.environ.get("POLICYENGINE_ALEMBIC_SCHEMA_SQL") +target_metadata = build_metadata_from_sql(schema_sql_path) + + +def _configure_context(connection=None, url: str | None = None) -> None: + options = { + "target_metadata": target_metadata, + "compare_type": False, + "compare_server_default": False, + } + if connection is not None: + context.configure(connection=connection, **options) + else: + context.configure( + url=url, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + **options, + ) + + +def run_migrations_offline() -> None: + _configure_context(url=config.get_main_option("sqlalchemy.url")) + with context.begin_transaction(): + 
context.run_migrations() + + +def run_migrations_online() -> None: + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + _configure_context(connection=connection) + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 000000000..11016301e --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/20260511_558935decda5_add_report_run_canonical_schema.py b/alembic/versions/20260511_558935decda5_add_report_run_canonical_schema.py new file mode 100644 index 000000000..065c8a53a --- /dev/null +++ b/alembic/versions/20260511_558935decda5_add_report_run_canonical_schema.py @@ -0,0 +1,73 @@ +"""Add report run canonical schema + +Revision ID: 558935decda5 +Revises: 60d38593ddc3 +Create Date: 2026-05-11 22:21:20.417733 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = "558935decda5" +down_revision: Union[str, Sequence[str], None] = "60d38593ddc3" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "legacy_report_output_id_map", + sa.Column("legacy_report_output_id", sa.INTEGER(), nullable=False), + sa.Column("canonical_report_output_id", sa.INTEGER(), nullable=False), + sa.Column("display_report_output_run_id", sa.CHAR(length=36), nullable=False), + sa.PrimaryKeyConstraint("legacy_report_output_id"), + ) + op.create_index( + "legacy_report_output_id_map_canonical_idx", + "legacy_report_output_id_map", + ["canonical_report_output_id"], + unique=False, + ) + op.add_column( + "report_outputs", + sa.Column("report_identity_hash", sa.VARCHAR(length=64), nullable=True), + ) + op.add_column( + "report_outputs", + sa.Column("report_identity_schema_version", sa.INTEGER(), nullable=True), + ) + op.create_index( + "report_outputs_identity_idx", + "report_outputs", + ["country_id", "report_identity_hash", "report_identity_schema_version"], + unique=False, + ) + op.create_index( + "simulation_runs_report_output_run_idx", + "simulation_runs", + ["report_output_run_id"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index("simulation_runs_report_output_run_idx", table_name="simulation_runs") + op.drop_index("report_outputs_identity_idx", table_name="report_outputs") + op.drop_column("report_outputs", "report_identity_schema_version") + op.drop_column("report_outputs", "report_identity_hash") + op.drop_index( + "legacy_report_output_id_map_canonical_idx", + table_name="legacy_report_output_id_map", + ) + op.drop_table("legacy_report_output_id_map") + # ### end Alembic commands ### diff --git a/alembic/versions/20260511_60d38593ddc3_initial_legacy_schema.py b/alembic/versions/20260511_60d38593ddc3_initial_legacy_schema.py new file mode 100644 index 000000000..39bd19baf --- /dev/null +++ b/alembic/versions/20260511_60d38593ddc3_initial_legacy_schema.py @@ -0,0 +1,264 @@ +"""Initial legacy schema + +Revision ID: 60d38593ddc3 +Revises: +Create Date: 2026-05-11 20:19:44.056995 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "60d38593ddc3" +down_revision: Union[str, Sequence[str], None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "analysis", + sa.Column("prompt_id", sa.INTEGER(), nullable=False), + sa.Column("prompt", sa.TEXT(), nullable=False), + sa.Column("analysis", sa.TEXT(), nullable=True), + sa.Column("status", sa.VARCHAR(length=32), nullable=False), + sa.PrimaryKeyConstraint("prompt_id"), + ) + op.create_table( + "computed_household", + sa.Column("household_id", sa.INTEGER(), nullable=False), + sa.Column("policy_id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("api_version", sa.VARCHAR(length=10), nullable=False), + sa.Column("computed_household_json", sa.JSON(), nullable=False), + sa.Column("status", sa.VARCHAR(length=32), nullable=True), + sa.PrimaryKeyConstraint("household_id", "policy_id", "country_id"), + ) + op.create_table( + "economy", + sa.Column("economy_id", sa.INTEGER(), nullable=False), + sa.Column("policy_id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("region", sa.VARCHAR(length=32), nullable=True), + sa.Column("time_period", sa.VARCHAR(length=32), nullable=True), + sa.Column("options_json", sa.JSON(), nullable=False), + sa.Column("options_hash", sa.VARCHAR(length=255), nullable=False), + sa.Column("api_version", sa.VARCHAR(length=10), nullable=False), + sa.Column("economy_json", sa.JSON(), nullable=True), + sa.Column("status", sa.VARCHAR(length=32), nullable=False), + sa.Column("message", sa.VARCHAR(length=255), nullable=True), + sa.PrimaryKeyConstraint("economy_id"), + ) + op.create_table( + "household", + sa.Column("id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("label", sa.VARCHAR(length=255), nullable=True), + sa.Column("api_version", sa.VARCHAR(length=255), nullable=False), + sa.Column("household_json", sa.JSON(), nullable=False), + sa.Column("household_hash", sa.VARCHAR(length=255), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + 
op.create_table( + "legacy_report_output_aliases", + sa.Column("legacy_report_output_id", sa.INTEGER(), nullable=False), + sa.Column("canonical_report_output_id", sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint("legacy_report_output_id"), + ) + op.create_table( + "policy", + sa.Column("id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("label", sa.VARCHAR(length=255), nullable=True), + sa.Column("api_version", sa.VARCHAR(length=10), nullable=False), + sa.Column("policy_json", sa.JSON(), nullable=False), + sa.Column("policy_hash", sa.VARCHAR(length=255), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "reform_impact", + sa.Column("reform_impact_id", sa.INTEGER(), nullable=False), + sa.Column("baseline_policy_id", sa.INTEGER(), nullable=False), + sa.Column("reform_policy_id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("region", sa.VARCHAR(length=32), nullable=False), + sa.Column("dataset", sa.VARCHAR(length=255), nullable=False), + sa.Column("time_period", sa.VARCHAR(length=32), nullable=False), + sa.Column("options_json", sa.JSON(), nullable=False), + sa.Column("options_hash", sa.VARCHAR(length=255), nullable=False), + sa.Column("api_version", sa.VARCHAR(length=10), nullable=False), + sa.Column("reform_impact_json", sa.JSON(), nullable=False), + sa.Column("status", sa.VARCHAR(length=32), nullable=False), + sa.Column("message", sa.VARCHAR(length=255), nullable=True), + sa.Column("start_time", sa.DATETIME(), nullable=False), + sa.Column("end_time", sa.DATETIME(), nullable=True), + sa.Column("execution_id", sa.VARCHAR(length=255), nullable=False), + sa.PrimaryKeyConstraint("reform_impact_id"), + ) + op.create_table( + "report_output_runs", + sa.Column("id", sa.CHAR(length=36), nullable=False), + sa.Column("report_output_id", sa.INTEGER(), nullable=False), + sa.Column("run_sequence", sa.INTEGER(), 
nullable=False), + sa.Column("status", sa.VARCHAR(length=32), nullable=False), + sa.Column("output", sa.JSON(), nullable=True), + sa.Column("error_message", sa.TEXT(), nullable=True), + sa.Column("trigger_type", sa.VARCHAR(length=32), nullable=False), + sa.Column("requested_at", sa.DATETIME(), nullable=True), + sa.Column("started_at", sa.DATETIME(), nullable=True), + sa.Column("finished_at", sa.DATETIME(), nullable=True), + sa.Column("source_run_id", sa.CHAR(length=36), nullable=True), + sa.Column("report_spec_snapshot_json", sa.JSON(), nullable=True), + sa.Column("country_package_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("policyengine_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("data_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("runtime_app_name", sa.VARCHAR(length=255), nullable=True), + sa.Column("report_cache_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("simulation_cache_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("requested_version_override", sa.VARCHAR(length=255), nullable=True), + sa.Column("resolved_dataset", sa.VARCHAR(length=255), nullable=True), + sa.Column("resolved_options_hash", sa.VARCHAR(length=255), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("report_output_id", "run_sequence"), + ) + op.create_table( + "report_outputs", + sa.Column("id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("simulation_1_id", sa.INTEGER(), nullable=False), + sa.Column("simulation_2_id", sa.INTEGER(), nullable=True), + sa.Column("api_version", sa.VARCHAR(length=10), nullable=False), + sa.Column( + "status", + sa.VARCHAR(length=32), + server_default=sa.text("'pending'"), + nullable=False, + ), + sa.Column("output", sa.JSON(), nullable=True), + sa.Column("error_message", sa.TEXT(), nullable=True), + sa.Column( + "year", + sa.VARCHAR(length=255), + server_default=sa.text("'2025'"), + nullable=True, 
+ ), + sa.Column("report_kind", sa.VARCHAR(length=64), nullable=True), + sa.Column("report_spec_json", sa.JSON(), nullable=True), + sa.Column("report_spec_schema_version", sa.INTEGER(), nullable=True), + sa.Column("report_spec_status", sa.VARCHAR(length=32), nullable=True), + sa.Column("active_run_id", sa.CHAR(length=36), nullable=True), + sa.Column("latest_successful_run_id", sa.CHAR(length=36), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "simulation_runs", + sa.Column("id", sa.CHAR(length=36), nullable=False), + sa.Column("simulation_id", sa.INTEGER(), nullable=False), + sa.Column("report_output_run_id", sa.CHAR(length=36), nullable=True), + sa.Column("input_position", sa.INTEGER(), nullable=True), + sa.Column("run_sequence", sa.INTEGER(), nullable=False), + sa.Column("status", sa.VARCHAR(length=32), nullable=False), + sa.Column("output", sa.JSON(), nullable=True), + sa.Column("error_message", sa.TEXT(), nullable=True), + sa.Column("trigger_type", sa.VARCHAR(length=32), nullable=False), + sa.Column("requested_at", sa.DATETIME(), nullable=True), + sa.Column("started_at", sa.DATETIME(), nullable=True), + sa.Column("finished_at", sa.DATETIME(), nullable=True), + sa.Column("source_run_id", sa.CHAR(length=36), nullable=True), + sa.Column("simulation_spec_snapshot_json", sa.JSON(), nullable=True), + sa.Column("country_package_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("policyengine_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("data_version", sa.VARCHAR(length=255), nullable=True), + sa.Column("runtime_app_name", sa.VARCHAR(length=255), nullable=True), + sa.Column("simulation_cache_version", sa.VARCHAR(length=255), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("simulation_id", "run_sequence"), + ) + op.create_table( + "simulations", + sa.Column("id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("api_version", 
sa.VARCHAR(length=10), nullable=False), + sa.Column("population_id", sa.VARCHAR(length=255), nullable=False), + sa.Column("population_type", sa.VARCHAR(length=50), nullable=False), + sa.Column("policy_id", sa.INTEGER(), nullable=False), + sa.Column( + "status", + sa.VARCHAR(length=32), + server_default=sa.text("'pending'"), + nullable=False, + ), + sa.Column("output", sa.JSON(), nullable=True), + sa.Column("error_message", sa.TEXT(), nullable=True), + sa.Column("simulation_spec_json", sa.JSON(), nullable=True), + sa.Column("simulation_spec_schema_version", sa.INTEGER(), nullable=True), + sa.Column("active_run_id", sa.CHAR(length=36), nullable=True), + sa.Column("latest_successful_run_id", sa.CHAR(length=36), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "tracers", + sa.Column("id", sa.INTEGER(), nullable=False), + sa.Column("household_id", sa.INTEGER(), nullable=False), + sa.Column("policy_id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("api_version", sa.VARCHAR(length=10), nullable=False), + sa.Column("tracer_output", sa.JSON(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "user_policies", + sa.Column("id", sa.INTEGER(), nullable=False), + sa.Column("country_id", sa.VARCHAR(length=3), nullable=False), + sa.Column("reform_id", sa.INTEGER(), nullable=False), + sa.Column("reform_label", sa.VARCHAR(length=255), nullable=True), + sa.Column("baseline_id", sa.INTEGER(), nullable=False), + sa.Column("baseline_label", sa.VARCHAR(length=255), nullable=True), + sa.Column("user_id", sa.VARCHAR(length=255), nullable=False), + sa.Column("year", sa.VARCHAR(length=32), nullable=False), + sa.Column("geography", sa.VARCHAR(length=255), nullable=False), + sa.Column("dataset", sa.VARCHAR(length=255), nullable=True), + sa.Column("number_of_provisions", sa.INTEGER(), nullable=False), + sa.Column("api_version", sa.VARCHAR(length=32), nullable=False), + 
sa.Column("added_date", sa.BIGINT(), nullable=False), + sa.Column("updated_date", sa.BIGINT(), nullable=False), + sa.Column("budgetary_impact", sa.VARCHAR(length=255), nullable=True), + sa.Column("type", sa.VARCHAR(length=255), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "user_profiles", + sa.Column("user_id", sa.INTEGER(), nullable=False), + sa.Column("auth0_id", sa.VARCHAR(length=255), nullable=False), + sa.Column("username", sa.VARCHAR(length=255), nullable=True), + sa.Column("primary_country", sa.VARCHAR(length=3), nullable=False), + sa.Column("user_since", sa.BIGINT(), nullable=False), + sa.PrimaryKeyConstraint("user_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("user_profiles") + op.drop_table("user_policies") + op.drop_table("tracers") + op.drop_table("simulations") + op.drop_table("simulation_runs") + op.drop_table("report_outputs") + op.drop_table("report_output_runs") + op.drop_table("reform_impact") + op.drop_table("policy") + op.drop_table("legacy_report_output_aliases") + op.drop_table("household") + op.drop_table("economy") + op.drop_table("computed_household") + op.drop_table("analysis") + # ### end Alembic commands ### diff --git a/changelog.d/3500.changed.md b/changelog.d/3500.changed.md new file mode 100644 index 000000000..e54fc3e08 --- /dev/null +++ b/changelog.d/3500.changed.md @@ -0,0 +1 @@ +Preserve explicit report definitions and execution metadata, key report creation by canonical report identity, resolve legacy report IDs through a permanent compatibility map, and add run-targeted report and simulation rerun updates. 
diff --git a/policyengine_api/data/alembic_metadata.py b/policyengine_api/data/alembic_metadata.py
new file mode 100644
index 000000000..580e79867
--- /dev/null
+++ b/policyengine_api/data/alembic_metadata.py
@@ -0,0 +1,79 @@
+"""Build Alembic target metadata from the existing SQL initializer."""
+
+from pathlib import Path
+import sqlite3
+
+from sqlalchemy import JSON, MetaData, create_engine
+
+
+DEFAULT_SCHEMA_SQL = Path(__file__).with_name("initialise_local.sql")
+
+# Columns whose reflected type is replaced with SQLAlchemy's generic JSON type.
+JSON_COLUMN_NAMES = {
+    "computed_household_json",
+    "economy_json",
+    "household_json",
+    "options_json",
+    "policy_json",
+    "reform_impact_json",
+    "report_spec_json",
+    "report_spec_snapshot_json",
+    "simulation_spec_json",
+    "simulation_spec_snapshot_json",
+    "tracer_output",
+    "output",
+}
+
+
+def _normalize_reflected_metadata(metadata: MetaData) -> None:
+    """Adjust reflected columns in place before Alembic compares them.
+
+    Columns named in ``JSON_COLUMN_NAMES`` get the generic ``JSON`` type,
+    primary-key columns are marked non-nullable, and a server default that is
+    literally ``NULL`` is removed.
+    """
+    for table in metadata.tables.values():
+        for column in table.columns:
+            if column.name in JSON_COLUMN_NAMES:
+                column.type = JSON()
+            if column.primary_key:
+                column.nullable = False
+            if column.server_default is not None:
+                default_arg = str(column.server_default.arg).strip().upper()
+                # Match exactly: a substring test would also discard real
+                # defaults that merely contain "NULL" (e.g. 'NULLABLE').
+                if default_arg == "NULL":
+                    column.server_default = None
+
+
+def build_metadata_from_sql(schema_sql_path: str | Path | None = None) -> MetaData:
+    """Reflect SQL initializer DDL into SQLAlchemy metadata for autogenerate.
+
+    API v1 still uses raw SQL rather than ORM models. This keeps Alembic's
+    autogenerate path tied to the existing initializer instead of maintaining a
+    second manually-authored schema definition.
+
+    The DDL is executed in a throwaway in-memory SQLite database, which is
+    closed once reflection has finished. ``schema_sql_path`` defaults to
+    ``DEFAULT_SCHEMA_SQL`` when omitted or empty.
+    """
+
+    schema_sql_path = Path(schema_sql_path or DEFAULT_SCHEMA_SQL)
+    schema_sql = schema_sql_path.read_text(encoding="utf-8")
+
+    connection = sqlite3.connect(":memory:")
+    try:
+        connection.executescript(schema_sql)
+
+        engine = create_engine("sqlite://", creator=lambda: connection)
+        try:
+            metadata = MetaData()
+            metadata.reflect(bind=engine)
+        finally:
+            # reflect() loads every table eagerly, so the engine can go now.
+            engine.dispose()
+    finally:
+        connection.close()
+
+    _normalize_reflected_metadata(metadata)
+    return metadata
diff --git a/policyengine_api/data/initialise.sql b/policyengine_api/data/initialise.sql
index 085f31c0b..51ee2710a 100644
--- a/policyengine_api/data/initialise.sql
+++ b/policyengine_api/data/initialise.sql
@@ -135,10 +135,17 @@ CREATE TABLE IF NOT EXISTS report_outputs (
     report_spec_json JSON DEFAULT NULL,
     report_spec_schema_version INT DEFAULT NULL,
     report_spec_status VARCHAR(32) DEFAULT NULL,
+    report_identity_hash VARCHAR(64) DEFAULT NULL,
+    report_identity_schema_version INT DEFAULT NULL,
     active_run_id CHAR(36) DEFAULT NULL,
     latest_successful_run_id CHAR(36) DEFAULT NULL
 );
 
+CREATE INDEX report_outputs_identity_idx
+    ON report_outputs (
+        country_id, report_identity_hash, report_identity_schema_version
+    );
+
 CREATE TABLE IF NOT EXISTS report_output_runs (
     id CHAR(36) PRIMARY KEY,
     report_output_id INT NOT NULL,
@@ -187,7 +194,19 @@ CREATE TABLE IF NOT EXISTS simulation_runs (
     UNIQUE KEY simulation_run_sequence_idx (simulation_id, run_sequence)
 );
 
+CREATE INDEX simulation_runs_report_output_run_idx
+    ON simulation_runs (report_output_run_id);
+
 CREATE TABLE IF NOT EXISTS legacy_report_output_aliases (
     legacy_report_output_id INT PRIMARY KEY,
     canonical_report_output_id INT NOT NULL
 );
+
+CREATE TABLE IF NOT EXISTS legacy_report_output_id_map (
+    legacy_report_output_id INT PRIMARY KEY,
+    canonical_report_output_id INT NOT NULL,
+    display_report_output_run_id CHAR(36) NOT NULL
+);
+
+CREATE INDEX legacy_report_output_id_map_canonical_idx
+    ON legacy_report_output_id_map (canonical_report_output_id);
diff --git a/policyengine_api/data/initialise_local.sql b/policyengine_api/data/initialise_local.sql
index
53a37b4c8..6aae6006b 100644 --- a/policyengine_api/data/initialise_local.sql +++ b/policyengine_api/data/initialise_local.sql @@ -9,6 +9,7 @@ DROP TABLE IF EXISTS tracers; DROP TABLE IF EXISTS report_output_runs; DROP TABLE IF EXISTS simulation_runs; DROP TABLE IF EXISTS legacy_report_output_aliases; +DROP TABLE IF EXISTS legacy_report_output_id_map; CREATE TABLE IF NOT EXISTS household ( id INTEGER PRIMARY KEY, @@ -147,10 +148,17 @@ CREATE TABLE IF NOT EXISTS report_outputs ( report_spec_json JSON DEFAULT NULL, report_spec_schema_version INT DEFAULT NULL, report_spec_status VARCHAR(32) DEFAULT NULL, + report_identity_hash VARCHAR(64) DEFAULT NULL, + report_identity_schema_version INT DEFAULT NULL, active_run_id CHAR(36) DEFAULT NULL, latest_successful_run_id CHAR(36) DEFAULT NULL ); +CREATE INDEX report_outputs_identity_idx + ON report_outputs ( + country_id, report_identity_hash, report_identity_schema_version + ); + CREATE TABLE IF NOT EXISTS report_output_runs ( id CHAR(36) PRIMARY KEY, report_output_id INT NOT NULL, @@ -199,7 +207,19 @@ CREATE TABLE IF NOT EXISTS simulation_runs ( UNIQUE (simulation_id, run_sequence) ); +CREATE INDEX simulation_runs_report_output_run_idx + ON simulation_runs (report_output_run_id); + CREATE TABLE IF NOT EXISTS legacy_report_output_aliases ( legacy_report_output_id INT PRIMARY KEY, canonical_report_output_id INT NOT NULL ); + +CREATE TABLE IF NOT EXISTS legacy_report_output_id_map ( + legacy_report_output_id INT PRIMARY KEY, + canonical_report_output_id INT NOT NULL, + display_report_output_run_id CHAR(36) NOT NULL +); + +CREATE INDEX legacy_report_output_id_map_canonical_idx + ON legacy_report_output_id_map (canonical_report_output_id); diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index 48a2ac43a..1b85d72ef 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -11,6 +11,25 @@ report_output_bp = 
Blueprint("report_output", __name__) report_output_service = ReportOutputService() +RUN_METADATA_FIELDS = ( + "country_package_version", + "policyengine_version", + "data_version", + "runtime_app_name", + "resolved_dataset", +) + + +def _parse_report_run_metadata(payload: dict) -> dict[str, str | None]: + metadata: dict[str, str | None] = {} + for field_name in RUN_METADATA_FIELDS: + if field_name not in payload: + continue + value = payload.get(field_name) + if value is not None and not isinstance(value, str): + raise BadRequest(f"{field_name} must be a string or null") + metadata[field_name] = value + return metadata @report_output_bp.route("//report", methods=["POST"]) @@ -36,6 +55,8 @@ def create_report_output(country_id: str) -> Response: simulation_1_id = payload.get("simulation_1_id") simulation_2_id = payload.get("simulation_2_id") # Optional year = payload.get("year", CURRENT_YEAR) # Default to current year as string + report_spec_payload = payload.get("report_spec") + report_spec_schema_version = payload.get("report_spec_schema_version") # Validate required fields if simulation_1_id is None: @@ -46,14 +67,35 @@ def create_report_output(country_id: str) -> Response: raise BadRequest("simulation_2_id must be an integer or null") if not isinstance(year, str): raise BadRequest("year must be a string") + if report_spec_payload is not None and not isinstance(report_spec_payload, dict): + raise BadRequest("report_spec must be an object") + if report_spec_schema_version is not None and not isinstance( + report_spec_schema_version, int + ): + raise BadRequest("report_spec_schema_version must be an integer") + + report_spec = None + if report_spec_payload is not None: + try: + report_spec = report_output_service.parse_report_spec_payload( + report_spec_payload, + ( + report_spec_schema_version + if report_spec_schema_version is not None + else 1 + ), + ) + except ValueError as exc: + raise BadRequest(str(exc)) from exc try: # Check if report already exists with 
these simulation IDs and year - existing_report = report_output_service.find_existing_report_output( + existing_report = report_output_service.find_existing_report_output_for_create( country_id=country_id, simulation_1_id=simulation_1_id, simulation_2_id=simulation_2_id, year=year, + report_spec=report_spec, ) if existing_report: @@ -61,6 +103,9 @@ def create_report_output(country_id: str) -> Response: report_output_service.ensure_report_output_dual_write_state( existing_report["id"], country_id=country_id, + explicit_report_spec=report_spec, + report_spec_schema_version=report_spec_schema_version, + ensure_current_report_cache_run=True, ) ) # Report already exists, return it with 200 status @@ -82,6 +127,8 @@ def create_report_output(country_id: str) -> Response: simulation_1_id=simulation_1_id, simulation_2_id=simulation_2_id, year=year, + report_spec=report_spec, + report_spec_schema_version=report_spec_schema_version, ) response_body = dict( @@ -149,6 +196,53 @@ def get_report_output(country_id: str, report_id: int) -> Response: ) +@report_output_bp.route("//report//rerun", methods=["POST"]) +@validate_country +def create_report_rerun(country_id: str, report_id: int) -> Response: + """ + Create a new pending run for an existing report. + + The requested report ID may be a legacy ID; the run is always created under + the resolved canonical report output. 
+ """ + payload = request.json or {} + if not isinstance(payload, dict): + raise BadRequest("Payload must be an object") + + version_manifest_overrides = _parse_report_run_metadata(payload) + + try: + if not report_output_service.report_output_exists(country_id, report_id): + raise NotFound(f"Report #{report_id} not found.") + + rerun = report_output_service.create_report_rerun( + country_id=country_id, + report_output_id=report_id, + version_manifest_overrides=version_manifest_overrides, + ) + except HTTPException: + raise + except ValueError as e: + current_app.logger.warning( + "Bad request creating report rerun #%s for country %s: %s", + report_id, + country_id, + e, + ) + raise BadRequest(f"Failed to create report rerun: {e}") from e + + response_body = dict( + status="ok", + message="Report rerun created successfully", + result=rerun, + ) + return Response( + json.dumps(response_body), + status=201, + mimetype="application/json", + ) + + @report_output_bp.route("//report", methods=["PATCH"]) @validate_country def update_report_output(country_id: str) -> Response: @@ -160,6 +254,7 @@ def update_report_output(country_id: str) -> Response: Request body can contain: - id (int): The report output ID. + - report_output_run_id (str | None): Specific report run to update. 
- status (str): The new status ('pending', 'running', 'complete', or 'error') - output (dict): The result output (for complete status) - api_version (str): The API version of the report @@ -175,6 +270,8 @@ def update_report_output(country_id: str) -> Response: report_id = payload.get("id") output = payload.get("output") error_message = payload.get("error_message") + report_output_run_id = payload.get("report_output_run_id") + version_manifest_overrides = _parse_report_run_metadata(payload) print(f"Updating report #{report_id} for country {country_id}") # Validate status if provided @@ -189,6 +286,8 @@ def update_report_output(country_id: str) -> Response: # Validate that complete status has output if status == "complete" and output is None: raise BadRequest("output is required when status is 'complete'") + if report_output_run_id is not None and not isinstance(report_output_run_id, str): + raise BadRequest("report_output_run_id must be a string") try: # First check if the report output exists without running pointer sync: @@ -204,16 +303,25 @@ def update_report_output(country_id: str) -> Response: status=status, output=output, error_message=error_message, + report_output_run_id=report_output_run_id, + version_manifest_overrides=version_manifest_overrides, ) if not success: raise BadRequest("No fields to update") - # Get the updated stored record so stale-runtime jobs do not appear to - # complete the current runtime lineage in the PATCH response. 
- updated_report = report_output_service.get_stored_report_output( - country_id, report_id - ) + if report_output_run_id is not None: + updated_report = report_output_service.get_report_output_for_run( + country_id, + report_id, + report_output_run_id, + ) + else: + updated_report = report_output_service.get_report_output( + country_id, report_id + ) + if updated_report is None: + raise NotFound(f"Report #{report_id} not found.") response_body = dict( status="ok", diff --git a/policyengine_api/routes/simulation_routes.py b/policyengine_api/routes/simulation_routes.py index f2bacd6cb..375e1d194 100644 --- a/policyengine_api/routes/simulation_routes.py +++ b/policyengine_api/routes/simulation_routes.py @@ -10,6 +10,24 @@ simulation_bp = Blueprint("simulation", __name__) simulation_service = SimulationService() +RUN_METADATA_FIELDS = ( + "country_package_version", + "policyengine_version", + "data_version", + "runtime_app_name", +) + + +def _parse_simulation_run_metadata(payload: dict) -> dict[str, str | None]: + metadata: dict[str, str | None] = {} + for field_name in RUN_METADATA_FIELDS: + if field_name not in payload: + continue + value = payload.get(field_name) + if value is not None and not isinstance(value, str): + raise BadRequest(f"{field_name} must be a string or null") + metadata[field_name] = value + return metadata @simulation_bp.route("//simulation", methods=["POST"]) @@ -160,6 +178,7 @@ def update_simulation(country_id: str) -> Response: Request body can contain: - id (int): The simulation ID. + - simulation_run_id (str | None): Specific simulation run to update. 
- status (str): The new status ('complete' or 'error') - output (dict): The result output (for complete status) - api_version (str): The API version of the simulation @@ -175,6 +194,8 @@ def update_simulation(country_id: str) -> Response: simulation_id = payload.get("id") output = payload.get("output") error_message = payload.get("error_message") + simulation_run_id = payload.get("simulation_run_id") + version_manifest_overrides = _parse_simulation_run_metadata(payload) print(f"Updating simulation #{simulation_id} for country {country_id}") # Validate status if provided @@ -184,6 +205,8 @@ def update_simulation(country_id: str) -> Response: # Validate that complete status has output if status == "complete" and output is None: raise BadRequest("output is required when status is 'complete'") + if simulation_run_id is not None and not isinstance(simulation_run_id, str): + raise BadRequest("simulation_run_id must be a string") try: # First check if the simulation exists @@ -200,15 +223,26 @@ def update_simulation(country_id: str) -> Response: status=status, output=output, error_message=error_message, + simulation_run_id=simulation_run_id, + version_manifest_overrides=version_manifest_overrides, ) if not success: raise BadRequest("No fields to update") - # Get the updated record - updated_simulation = simulation_service.get_simulation( - country_id, simulation_id - ) + if simulation_run_id is not None: + updated_simulation = simulation_service.get_simulation_for_run( + country_id, + simulation_id, + simulation_run_id, + ) + else: + updated_simulation = simulation_service.get_simulation( + country_id, + simulation_id, + ) + if updated_simulation is None: + raise NotFound(f"Simulation #{simulation_id} not found.") response_body = dict( status="ok", diff --git a/policyengine_api/services/report_output_alias_service.py b/policyengine_api/services/report_output_alias_service.py deleted file mode 100644 index 9440cfdfd..000000000 --- 
a/policyengine_api/services/report_output_alias_service.py +++ /dev/null @@ -1,97 +0,0 @@ -from sqlalchemy.engine.row import Row - -from policyengine_api.data import database - - -class ReportOutputAliasService: - def _get_report_output_row(self, report_output_id: int) -> dict | None: - row: Row | None = database.query( - """ - SELECT id, country_id, simulation_1_id, simulation_2_id, year - FROM report_outputs - WHERE id = ? - """, - (report_output_id,), - ).fetchone() - return dict(row) if row is not None else None - - def get_alias(self, legacy_report_output_id: int) -> dict | None: - row: Row | None = database.query( - """ - SELECT * FROM legacy_report_output_aliases - WHERE legacy_report_output_id = ? - """, - (legacy_report_output_id,), - ).fetchone() - return dict(row) if row is not None else None - - def resolve_canonical_report_output_id( - self, requested_report_output_id: int - ) -> int | None: - alias = self.get_alias(requested_report_output_id) - if alias is not None: - canonical_report_output_id = alias["canonical_report_output_id"] - if self._get_report_output_row(canonical_report_output_id) is None: - raise ValueError( - "Alias points to missing canonical report output " - f"#{canonical_report_output_id}" - ) - return canonical_report_output_id - - row: Row | None = database.query( - "SELECT id FROM report_outputs WHERE id = ?", - (requested_report_output_id,), - ).fetchone() - return row["id"] if row is not None else None - - def set_alias( - self, - legacy_report_output_id: int, - canonical_report_output_id: int, - ) -> bool: - legacy_report_output = self._get_report_output_row(legacy_report_output_id) - if legacy_report_output is None: - raise ValueError( - f"Legacy report output #{legacy_report_output_id} not found" - ) - - canonical_report_output = self._get_report_output_row( - canonical_report_output_id - ) - if canonical_report_output is None: - raise ValueError( - f"Canonical report output #{canonical_report_output_id} not found" - ) - if 
legacy_report_output_id == canonical_report_output_id: - raise ValueError("Legacy and canonical report outputs must be different") - - existing_alias = self.get_alias(legacy_report_output_id) - if existing_alias is not None: - if ( - existing_alias["canonical_report_output_id"] - == canonical_report_output_id - ): - return True - - raise ValueError( - "Legacy report output alias already points to canonical report output " - f"#{existing_alias['canonical_report_output_id']}" - ) - - logical_key = ("country_id", "simulation_1_id", "simulation_2_id", "year") - if any( - legacy_report_output[field] != canonical_report_output[field] - for field in logical_key - ): - raise ValueError( - "Legacy and canonical report outputs must describe the same report" - ) - database.query( - """ - INSERT INTO legacy_report_output_aliases - (legacy_report_output_id, canonical_report_output_id) - VALUES (?, ?) - """, - (legacy_report_output_id, canonical_report_output_id), - ) - return True diff --git a/policyengine_api/services/report_output_id_map_service.py b/policyengine_api/services/report_output_id_map_service.py new file mode 100644 index 000000000..1fc27c8fa --- /dev/null +++ b/policyengine_api/services/report_output_id_map_service.py @@ -0,0 +1,244 @@ +from sqlalchemy.engine.row import Row + +from policyengine_api.data import database + + +class ReportOutputIdMapService: + def _get_report_output_row( + self, + report_output_id: int, + *, + queryer=None, + country_id: str | None = None, + ) -> dict | None: + queryer = queryer or database + query = """ + SELECT id, country_id, report_identity_hash, + report_identity_schema_version + FROM report_outputs + WHERE id = ? + """ + params: list[int | str] = [report_output_id] + if country_id is not None: + query += " AND country_id = ?" 
+ params.append(country_id) + + row: Row | None = queryer.query(query, tuple(params)).fetchone() + return dict(row) if row is not None else None + + def _get_report_output_run_row( + self, + report_output_run_id: str, + *, + canonical_report_output_id: int, + queryer=None, + ) -> dict | None: + queryer = queryer or database + row: Row | None = queryer.query( + """ + SELECT id, report_output_id + FROM report_output_runs + WHERE id = ? AND report_output_id = ? + """, + (report_output_run_id, canonical_report_output_id), + ).fetchone() + return dict(row) if row is not None else None + + def _validate_mapping_identity_compatibility( + self, + legacy_report_output: dict, + canonical_report_output: dict, + ) -> None: + if legacy_report_output["country_id"] != canonical_report_output["country_id"]: + raise ValueError( + "Legacy and canonical report outputs must describe the same report" + ) + + if ( + legacy_report_output["report_identity_hash"] is None + or legacy_report_output["report_identity_schema_version"] is None + ): + raise ValueError( + "Legacy report output must have canonical report identity before " + "mapping" + ) + + if ( + canonical_report_output["report_identity_hash"] is None + or canonical_report_output["report_identity_schema_version"] is None + ): + raise ValueError( + "Canonical report output must have canonical report identity before " + "mapping" + ) + + if ( + legacy_report_output["report_identity_hash"] + != canonical_report_output["report_identity_hash"] + or legacy_report_output["report_identity_schema_version"] + != canonical_report_output["report_identity_schema_version"] + ): + raise ValueError( + "Legacy and canonical report outputs must share canonical report " + "identity" + ) + + def get_mapping( + self, + legacy_report_output_id: int, + *, + queryer=None, + ) -> dict | None: + queryer = queryer or database + row: Row | None = queryer.query( + """ + SELECT * FROM legacy_report_output_id_map + WHERE legacy_report_output_id = ? 
+ """, + (legacy_report_output_id,), + ).fetchone() + return dict(row) if row is not None else None + + def resolve_report_output_id( + self, + requested_report_output_id: int, + *, + queryer=None, + country_id: str | None = None, + ) -> dict | None: + queryer = queryer or database + mapping = self.get_mapping(requested_report_output_id, queryer=queryer) + if mapping is not None: + canonical_report_output_id = mapping["canonical_report_output_id"] + canonical_report_output = self._get_report_output_row( + canonical_report_output_id, + queryer=queryer, + ) + if canonical_report_output is None: + raise ValueError( + "Legacy ID mapping points to missing canonical report output " + f"#{canonical_report_output_id}" + ) + if ( + country_id is not None + and canonical_report_output["country_id"] != country_id + ): + return None + display_report_output_run_id = mapping["display_report_output_run_id"] + display_run = self._get_report_output_run_row( + display_report_output_run_id, + canonical_report_output_id=canonical_report_output_id, + queryer=queryer, + ) + if display_run is None: + raise ValueError( + "Legacy ID mapping points to missing display report output run " + f"#{display_report_output_run_id} for canonical report output " + f"#{canonical_report_output_id}" + ) + return { + "requested_report_output_id": requested_report_output_id, + "canonical_report_output_id": canonical_report_output_id, + "display_report_output_run_id": display_report_output_run_id, + "is_legacy_id": True, + } + + requested_report_output = self._get_report_output_row( + requested_report_output_id, + queryer=queryer, + country_id=country_id, + ) + if requested_report_output is None: + return None + + return { + "requested_report_output_id": requested_report_output_id, + "canonical_report_output_id": requested_report_output_id, + "display_report_output_run_id": None, + "is_legacy_id": False, + } + + def resolve_canonical_report_output_id( + self, + requested_report_output_id: int, + *, + 
queryer=None, + country_id: str | None = None, + ) -> int | None: + resolution = self.resolve_report_output_id( + requested_report_output_id, + queryer=queryer, + country_id=country_id, + ) + if resolution is None: + return None + return resolution["canonical_report_output_id"] + + def set_mapping( + self, + legacy_report_output_id: int, + canonical_report_output_id: int, + display_report_output_run_id: str, + ) -> bool: + if legacy_report_output_id == canonical_report_output_id: + raise ValueError("Legacy and canonical report outputs must be different") + + canonical_report_output = self._get_report_output_row( + canonical_report_output_id + ) + if canonical_report_output is None: + raise ValueError( + f"Canonical report output #{canonical_report_output_id} not found" + ) + + existing_mapping = self.get_mapping(legacy_report_output_id) + if existing_mapping is not None: + if ( + existing_mapping["canonical_report_output_id"] + == canonical_report_output_id + and existing_mapping["display_report_output_run_id"] + == display_report_output_run_id + ): + return True + + raise ValueError( + "Legacy report output ID already maps to canonical report output " + f"#{existing_mapping['canonical_report_output_id']} and display " + f"run #{existing_mapping['display_report_output_run_id']}" + ) + + legacy_report_output = self._get_report_output_row(legacy_report_output_id) + if legacy_report_output is not None: + self._validate_mapping_identity_compatibility( + legacy_report_output, + canonical_report_output, + ) + + display_run = self._get_report_output_run_row( + display_report_output_run_id, + canonical_report_output_id=canonical_report_output_id, + ) + if display_run is None: + raise ValueError( + "Display report output run " + f"#{display_report_output_run_id} not found for canonical report " + f"#{canonical_report_output_id}" + ) + + database.query( + """ + INSERT INTO legacy_report_output_id_map + ( + legacy_report_output_id, + canonical_report_output_id, + 
display_report_output_run_id + ) + VALUES (?, ?, ?) + """, + ( + legacy_report_output_id, + canonical_report_output_id, + display_report_output_run_id, + ), + ) + return True diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index 38b5704fa..b507279d8 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -1,13 +1,17 @@ -import uuid from datetime import datetime, timezone from sqlalchemy.engine.row import Row from policyengine_api.constants import get_report_output_cache_version from policyengine_api.data import database +from policyengine_api.services.report_output_id_map_service import ( + ReportOutputIdMapService, +) +from policyengine_api.services.report_run_service import ReportRunService from policyengine_api.services.report_spec_service import ( ECONOMY_REPORT_KINDS, ReportSpec, + REPORT_SPEC_SCHEMA_VERSION, ReportSpecService, ) from policyengine_api.services.run_sync_utils import ( @@ -24,6 +28,8 @@ class ReportOutputService: def __init__(self): self.report_spec_service = ReportSpecService() self.simulation_service = SimulationService() + self.report_output_id_map_service = ReportOutputIdMapService() + self.report_run_service = ReportRunService() def _lock_clause(self) -> str: return "" if database.local else " FOR UPDATE" @@ -90,6 +96,17 @@ def _get_report_output_row( row: Row | None = queryer.query(query, tuple(params)).fetchone() return dict(row) if row is not None else None + def _get_last_inserted_report_output_id(self, tx) -> int: + query = ( + "SELECT last_insert_rowid() AS id" + if database.local + else "SELECT LAST_INSERT_ID() AS id" + ) + row = tx.query(query).fetchone() + if row is None or row["id"] is None: + raise Exception("Failed to retrieve inserted report output ID") + return int(row["id"]) + def _get_linked_simulations( self, report_output: dict, @@ -178,6 +195,33 @@ def _list_report_runs_descending( 
runs.append(run) return runs + def _get_report_run_row( + self, + run_id: str, + *, + queryer=None, + report_output_id: int | None = None, + for_update: bool = False, + ) -> dict | None: + queryer = queryer or database + query = "SELECT * FROM report_output_runs WHERE id = ?" + params: list[str | int] = [run_id] + if report_output_id is not None: + query += " AND report_output_id = ?" + params.append(report_output_id) + if for_update: + query += self._lock_clause() + + row: Row | None = queryer.query(query, tuple(params)).fetchone() + if row is None: + return None + + run = dict(row) + run["report_spec_snapshot_json"] = parse_json_field( + run.get("report_spec_snapshot_json") + ) + return run + def _select_mutable_run( self, report_output: dict, runs_descending: list[dict] ) -> dict | None: @@ -276,18 +320,30 @@ def _derive_report_country_package_version( return versions[0] return None - def _build_version_manifest( + def _merge_version_manifest_overrides( + self, + version_manifest: dict[str, str | None], + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + merged_manifest = dict(version_manifest) + for key, value in (version_manifest_overrides or {}).items(): + if key in merged_manifest and value is not None: + merged_manifest[key] = value + return merged_manifest + + def _build_bootstrap_version_manifest( self, report_output: dict, report_spec: ReportSpec | None, simulation_1: dict | None = None, simulation_2: dict | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict[str, str | None]: resolved_dataset = None if report_spec is not None and report_spec.report_kind in ECONOMY_REPORT_KINDS: resolved_dataset = report_spec.dataset - return { + version_manifest = { "country_package_version": self._derive_report_country_package_version( simulation_1, simulation_2 ), @@ -300,22 +356,162 @@ def _build_version_manifest( "resolved_dataset": resolved_dataset, "resolved_options_hash": None, } 
+ return self._merge_version_manifest_overrides( + version_manifest, + version_manifest_overrides=version_manifest_overrides, + ) + + def _build_existing_run_version_manifest( + self, + run: dict, + report_output: dict, + report_spec: ReportSpec | None, + simulation_1: dict | None = None, + simulation_2: dict | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + fallback_manifest = self._build_bootstrap_version_manifest( + report_output, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + ) + version_manifest = { + key: run.get(key) + if run.get(key) is not None + else fallback_manifest.get(key) + for key in fallback_manifest + } + return self._merge_version_manifest_overrides( + version_manifest, + version_manifest_overrides=version_manifest_overrides, + ) def _get_report_spec_status(self, report_spec: ReportSpec) -> str: if report_spec.report_kind in ECONOMY_REPORT_KINDS: return "backfilled_assumed" return "explicit" - def _upsert_report_spec_in_transaction( + def _persist_explicit_report_spec_in_transaction( self, tx, report_output: dict, - simulation_1: dict | None, + simulation_1: dict, + simulation_2: dict | None, + explicit_report_spec: ReportSpec, + report_spec_schema_version: int | None = None, + ) -> ReportSpec: + schema_version = ( + report_spec_schema_version + if report_spec_schema_version is not None + else REPORT_SPEC_SCHEMA_VERSION + ) + self.report_spec_service._validate_schema_version(schema_version) + self.report_spec_service.validate_report_spec_matches_context( + report_output, + explicit_report_spec, + simulation_1, + simulation_2, + ) + report_spec_status = "explicit" + existing_spec = parse_json_field(report_output.get("report_spec_json")) + if ( + existing_spec != explicit_report_spec.model_dump() + or report_output.get("report_kind") != explicit_report_spec.report_kind + or report_output.get("report_spec_schema_version") != schema_version + or 
report_output.get("report_spec_status") != report_spec_status + ): + tx.query( + """ + UPDATE report_outputs + SET report_kind = ?, report_spec_json = ?, + report_spec_schema_version = ?, report_spec_status = ? + WHERE id = ? + """, + ( + explicit_report_spec.report_kind, + explicit_report_spec.model_dump_json(), + schema_version, + report_spec_status, + report_output["id"], + ), + ) + report_output["report_kind"] = explicit_report_spec.report_kind + report_output["report_spec_json"] = explicit_report_spec.model_dump() + report_output["report_spec_schema_version"] = schema_version + report_output["report_spec_status"] = report_spec_status + return explicit_report_spec + + def _sync_report_identity_in_transaction( + self, + tx, + report_output: dict, + report_spec: ReportSpec | None, + ) -> None: + if report_spec is None: + return + + report_identity_hash, report_identity_schema_version = ( + self.report_spec_service.get_report_identity(report_spec) + ) + if ( + report_output.get("report_identity_hash") == report_identity_hash + and report_output.get("report_identity_schema_version") + == report_identity_schema_version + ): + return + + tx.query( + """ + UPDATE report_outputs + SET report_identity_hash = ?, report_identity_schema_version = ? + WHERE id = ? 
+ """, + ( + report_identity_hash, + report_identity_schema_version, + report_output["id"], + ), + ) + report_output["report_identity_hash"] = report_identity_hash + report_output["report_identity_schema_version"] = report_identity_schema_version + + def _load_existing_explicit_report_spec( + self, + report_output: dict, + simulation_1: dict, simulation_2: dict | None, ) -> ReportSpec | None: - if simulation_1 is None: + if report_output.get("report_spec_status") != "explicit": return None + raw_spec = parse_json_field(report_output.get("report_spec_json")) + if raw_spec is None: + raise ValueError("Stored explicit report spec is missing report_spec_json") + + report_spec = self.report_spec_service.parse_report_spec( + raw_spec, + schema_version=report_output.get("report_spec_schema_version"), + ) + if report_output.get("report_kind") != report_spec.report_kind: + raise ValueError( + "Stored explicit report kind must match stored report spec" + ) + self.report_spec_service.validate_report_spec_matches_context( + report_output, + report_spec, + simulation_1, + simulation_2, + ) + return report_spec + + def _derive_and_upsert_report_spec_in_transaction( + self, + tx, + report_output: dict, + simulation_1: dict, + simulation_2: dict | None, + ) -> ReportSpec | None: try: report_spec = self.report_spec_service.build_report_spec( report_output=report_output, @@ -359,6 +555,51 @@ def _upsert_report_spec_in_transaction( return report_spec + def _upsert_report_spec_in_transaction( + self, + tx, + report_output: dict, + simulation_1: dict | None, + simulation_2: dict | None, + explicit_report_spec: ReportSpec | None = None, + report_spec_schema_version: int | None = None, + ) -> ReportSpec | None: + if simulation_1 is None: + if explicit_report_spec is not None: + raise ValueError( + "Explicit report specs require linked simulations to be present" + ) + if report_output.get("report_spec_status") == "explicit": + raise ValueError( + "Stored explicit report specs require 
linked simulations to be present" + ) + return None + + if explicit_report_spec is not None: + return self._persist_explicit_report_spec_in_transaction( + tx, + report_output, + simulation_1, + simulation_2, + explicit_report_spec, + report_spec_schema_version=report_spec_schema_version, + ) + + stored_explicit_report_spec = self._load_existing_explicit_report_spec( + report_output, + simulation_1, + simulation_2, + ) + if stored_explicit_report_spec is not None: + return stored_explicit_report_spec + + return self._derive_and_upsert_report_spec_in_transaction( + tx, + report_output, + simulation_1, + simulation_2, + ) + def _run_matches_parent( self, run: dict, @@ -398,47 +639,62 @@ def _insert_bootstrap_report_run( report_spec: ReportSpec | None, version_manifest: dict[str, str | None], ) -> None: - requested_at = self._utc_timestamp() - is_terminal = report_output["status"] in ("complete", "error") - has_started = report_output["status"] in ("running", "complete", "error") - started_at = requested_at if has_started else None - finished_at = requested_at if is_terminal else None + self.report_run_service.create_report_output_run_in_transaction( + tx, + report_output["id"], + status=report_output["status"], + trigger_type="initial", + output=report_output.get("output"), + error_message=report_output.get("error_message"), + report_spec_snapshot=( + report_spec.model_dump() if report_spec is not None else None + ), + version_manifest=version_manifest, + ) - tx.query( - """ - INSERT INTO report_output_runs ( - id, report_output_id, run_sequence, status, output, error_message, - trigger_type, requested_at, started_at, finished_at, source_run_id, - report_spec_snapshot_json, country_package_version, policyengine_version, - data_version, runtime_app_name, report_cache_version, - simulation_cache_version, requested_version_override, resolved_dataset, - resolved_options_hash - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, - ( - str(uuid.uuid4()), - report_output["id"], - 1, - report_output["status"], - serialize_json_field(report_output.get("output")), - report_output.get("error_message"), - "initial", - requested_at, - started_at, - finished_at, - None, - (report_spec.model_dump_json() if report_spec is not None else None), - version_manifest["country_package_version"], - version_manifest["policyengine_version"], - version_manifest["data_version"], - version_manifest["runtime_app_name"], - version_manifest["report_cache_version"], - version_manifest["simulation_cache_version"], - version_manifest["requested_version_override"], - version_manifest["resolved_dataset"], - version_manifest["resolved_options_hash"], + def _ensure_current_report_cache_run_in_transaction( + self, + tx, + report_output: dict, + report_spec: ReportSpec | None, + simulation_1: dict | None, + simulation_2: dict | None, + runs_descending: list[dict], + ) -> list[dict]: + current_report_cache_version = get_report_output_cache_version( + report_output["country_id"] + ) + if any( + run.get("report_cache_version") == current_report_cache_version + for run in runs_descending + ): + return runs_descending + + source_run = select_display_report_run(report_output, runs_descending) + version_manifest = self._build_bootstrap_version_manifest( + report_output, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + ) + version_manifest["report_cache_version"] = current_report_cache_version + + self.report_run_service.create_report_output_run_in_transaction( + tx, + report_output["id"], + status="pending", + trigger_type="rerun", + source_run_id=(source_run["id"] if source_run is not None else None), + report_spec_snapshot=( + report_spec.model_dump() if report_spec is not None else None ), + version_manifest=version_manifest, ) + report_output["status"] = "pending" + report_output["output"] = None + report_output["error_message"] = None + report_output["api_version"] = 
current_report_cache_version + return self._list_report_runs_descending(report_output["id"], queryer=tx) def _update_report_run_in_transaction( self, @@ -539,12 +795,67 @@ def _sync_parent_pointers_in_transaction( report_output["active_run_id"] = desired_active_run_id report_output["latest_successful_run_id"] = desired_latest_successful_run_id + def _sync_parent_mirror_from_display_run_in_transaction( + self, + tx, + report_output: dict, + runs_descending: list[dict], + ) -> dict: + self._sync_parent_pointers_in_transaction(tx, report_output, runs_descending) + display_run = select_display_report_run(report_output, runs_descending) + if display_run is None: + refreshed_report_output = self._get_report_output_row( + report_output["id"], + queryer=tx, + country_id=report_output["country_id"], + ) + if refreshed_report_output is None: + raise ValueError( + f"Report output #{report_output['id']} not found after sync" + ) + return refreshed_report_output + + parent_api_version = ( + display_run["report_cache_version"] + if display_run.get("report_cache_version") is not None + else report_output.get("api_version") + ) + tx.query( + """ + UPDATE report_outputs + SET status = ?, output = ?, error_message = ?, api_version = ? + WHERE id = ? AND country_id = ? 
+ """, + ( + display_run["status"], + serialize_json_field(display_run.get("output")), + display_run.get("error_message"), + parent_api_version, + report_output["id"], + report_output["country_id"], + ), + ) + refreshed_report_output = self._get_report_output_row( + report_output["id"], + queryer=tx, + country_id=report_output["country_id"], + ) + if refreshed_report_output is None: + raise ValueError( + f"Report output #{report_output['id']} not found after sync" + ) + return refreshed_report_output + def _ensure_report_output_dual_write_state_in_transaction( self, tx, report_output_id: int, *, country_id: str | None = None, + explicit_report_spec: ReportSpec | None = None, + report_spec_schema_version: int | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, + ensure_current_report_cache_run: bool = False, ) -> dict: report_output = self._get_report_output_row( report_output_id, @@ -562,6 +873,8 @@ def _ensure_report_output_dual_write_state_in_transaction( bootstrap_dual_write_state=True, ) except ValueError as exc: + if explicit_report_spec is not None: + raise print( "Skipping linked simulation sync for report output " f"#{report_output_id}. 
Details: {str(exc)}" @@ -573,17 +886,21 @@ def _ensure_report_output_dual_write_state_in_transaction( report_output, simulation_1, simulation_2, + explicit_report_spec=explicit_report_spec, + report_spec_schema_version=report_spec_schema_version, ) - version_manifest = self._build_version_manifest( - report_output, - report_spec=report_spec, - simulation_1=simulation_1, - simulation_2=simulation_2, - ) + self._sync_report_identity_in_transaction(tx, report_output, report_spec) runs_descending = self._list_report_runs_descending( report_output_id, queryer=tx ) if not runs_descending: + version_manifest = self._build_bootstrap_version_manifest( + report_output, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) self._insert_bootstrap_report_run( tx, report_output, @@ -595,6 +912,24 @@ def _ensure_report_output_dual_write_state_in_transaction( ) else: mutable_run = self._select_mutable_run(report_output, runs_descending) + version_manifest = ( + self._build_existing_run_version_manifest( + mutable_run, + report_output, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) + if mutable_run is not None + else self._build_bootstrap_version_manifest( + report_output, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) + ) if mutable_run is not None: run_matches_parent = self._run_matches_parent( mutable_run, @@ -621,6 +956,26 @@ def _ensure_report_output_dual_write_state_in_transaction( report_output_id, queryer=tx ) + if ensure_current_report_cache_run: + runs_descending = self._ensure_current_report_cache_run_in_transaction( + tx, + report_output, + report_spec, + simulation_1, + simulation_2, + runs_descending, + ) + refreshed_report_output = ( + 
self._sync_parent_mirror_from_display_run_in_transaction( + tx, + report_output, + runs_descending, + ) + ) + return self._with_display_run_timestamps( + refreshed_report_output, queryer=tx + ) + self._sync_parent_pointers_in_transaction(tx, report_output, runs_descending) refreshed_report_output = self._get_report_output_row( report_output_id, @@ -635,20 +990,38 @@ def ensure_report_output_dual_write_state( self, report_output_id: int, country_id: str | None = None, + explicit_report_spec: ReportSpec | None = None, + report_spec_schema_version: int | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, + ensure_current_report_cache_run: bool = False, ) -> dict: return database.transaction( lambda tx: self._ensure_report_output_dual_write_state_in_transaction( tx, report_output_id, country_id=country_id, + explicit_report_spec=explicit_report_spec, + report_spec_schema_version=report_spec_schema_version, + version_manifest_overrides=version_manifest_overrides, + ensure_current_report_cache_run=ensure_current_report_cache_run, ) ) + def parse_report_spec_payload( + self, + raw_report_spec: dict, + schema_version: int = REPORT_SPEC_SCHEMA_VERSION, + ) -> ReportSpec: + return self.report_spec_service.parse_report_spec( + raw_report_spec, + schema_version=schema_version, + ) + def get_stored_report_output( self, country_id: str, report_output_id: int ) -> dict | None: """ - Get a stored report output row without aliasing to current runtime lineage. + Get a stored report output row without resolving legacy ID mappings. This is used by mutation paths that must address the originally requested row. 
It still runs dual-write synchronization, so it may @@ -670,15 +1043,13 @@ def get_stored_report_output( def report_output_exists(self, country_id: str, report_output_id: int) -> bool: return ( - self._get_report_output_row(report_output_id, country_id=country_id) + self.report_output_id_map_service.resolve_report_output_id( + report_output_id, + country_id=country_id, + ) is not None ) - def _is_current_report_output(self, report_output: dict) -> bool: - return report_output.get("api_version") == get_report_output_cache_version( - report_output["country_id"] - ) - def _find_existing_report_output_row( self, *, @@ -689,73 +1060,301 @@ def _find_existing_report_output_row( queryer=None, ) -> dict | None: queryer = queryer or database - api_version = get_report_output_cache_version(country_id) query = """ SELECT * FROM report_outputs - WHERE country_id = ? AND simulation_1_id = ? AND year = ? AND api_version = ? + WHERE country_id = ? AND simulation_1_id = ? AND year = ? """ - params: list[int | str] = [country_id, simulation_1_id, year, api_version] + params: list[int | str] = [country_id, simulation_1_id, year] if simulation_2_id is not None: query += " AND simulation_2_id = ?" 
params.append(simulation_2_id) else: query += " AND simulation_2_id IS NULL" - query += " ORDER BY id DESC" + query += " ORDER BY id ASC" row = queryer.query(query, tuple(params)).fetchone() return dict(row) if row is not None else None - def _get_or_create_current_report_output(self, report_output: dict) -> dict: - current_report = self.find_existing_report_output( - country_id=report_output["country_id"], - simulation_1_id=report_output["simulation_1_id"], - simulation_2_id=report_output["simulation_2_id"], - year=report_output["year"], - ) - if current_report is not None: - return self._with_display_run_timestamps(current_report) - - return self.create_report_output( - country_id=report_output["country_id"], - simulation_1_id=report_output["simulation_1_id"], - simulation_2_id=report_output["simulation_2_id"], - year=report_output["year"], - ) - - def _alias_report_output(self, report_output_id: int, report_output: dict) -> dict: - aliased_report = dict(report_output) - aliased_report["id"] = report_output_id - return aliased_report - - def find_existing_report_output( + def _find_existing_report_output_row_by_identity( self, + *, country_id: str, - simulation_1_id: int, - simulation_2_id: int | None = None, - year: str = "2025", + report_identity_hash: str, + report_identity_schema_version: int, + queryer=None, ) -> dict | None: - """ - Find an existing report output with the same simulation IDs and year. - """ - print("Checking for existing report output") - - try: - existing_report = self._find_existing_report_output_row( + queryer = queryer or database + row = queryer.query( + """ + SELECT * FROM report_outputs + WHERE country_id = ? AND report_identity_hash = ? + AND report_identity_schema_version = ? 
+ ORDER BY id ASC + """, + ( + country_id, + report_identity_hash, + report_identity_schema_version, + ), + ).fetchone() + return dict(row) if row is not None else None + + def _list_report_output_rows_by_legacy_key( + self, + *, + country_id: str, + simulation_1_id: int, + simulation_2_id: int | None, + year: str, + queryer=None, + ) -> list[dict]: + queryer = queryer or database + query = """ + SELECT * FROM report_outputs + WHERE country_id = ? AND simulation_1_id = ? AND year = ? + """ + params: list[int | str] = [country_id, simulation_1_id, year] + if simulation_2_id is not None: + query += " AND simulation_2_id = ?" + params.append(simulation_2_id) + else: + query += " AND simulation_2_id IS NULL" + query += " ORDER BY id ASC" + + rows = queryer.query(query, tuple(params)).fetchall() + return [dict(row) for row in rows] + + def _build_report_spec_for_create( + self, + *, + country_id: str, + simulation_1_id: int, + simulation_2_id: int | None, + year: str, + queryer=None, + ) -> ReportSpec: + queryer = queryer or database + simulation_1 = self._require_simulation_exists( + queryer, + country_id=country_id, + simulation_id=simulation_1_id, + ) + + simulation_2 = None + if simulation_2_id is not None: + simulation_2 = self._require_simulation_exists( + queryer, + country_id=country_id, + simulation_id=simulation_2_id, + ) + + return self.report_spec_service.build_report_spec( + report_output={ + "country_id": country_id, + "simulation_1_id": simulation_1_id, + "simulation_2_id": simulation_2_id, + "year": year, + }, + simulation_1=simulation_1, + simulation_2=simulation_2, + ) + + def _validate_explicit_report_spec_for_create( + self, + *, + country_id: str, + simulation_1_id: int, + simulation_2_id: int | None, + year: str, + report_spec: ReportSpec, + queryer, + ) -> None: + simulation_1 = self._require_simulation_exists( + queryer, + country_id=country_id, + simulation_id=simulation_1_id, + ) + simulation_2 = None + if simulation_2_id is not None: + 
simulation_2 = self._require_simulation_exists( + queryer, + country_id=country_id, + simulation_id=simulation_2_id, + ) + + self.report_spec_service.validate_report_spec_matches_context( + { + "country_id": country_id, + "simulation_1_id": simulation_1_id, + "simulation_2_id": simulation_2_id, + "year": year, + }, + report_spec, + simulation_1, + simulation_2, + ) + + def _get_report_spec_for_identity_matching( + self, + report_output: dict, + *, + queryer=None, + ) -> ReportSpec | None: + queryer = queryer or database + try: + simulation_1, simulation_2 = self._get_linked_simulations( + report_output, + queryer=queryer, + ) + except ValueError: + return None + + raw_spec = parse_json_field(report_output.get("report_spec_json")) + if ( + raw_spec is not None + and report_output.get("report_spec_schema_version") is not None + ): + try: + report_spec = self.report_spec_service.parse_report_spec( + raw_spec, + schema_version=report_output["report_spec_schema_version"], + ) + self.report_spec_service.validate_report_spec_matches_context( + report_output, + report_spec, + simulation_1, + simulation_2, + ) + return report_spec + except ValueError: + return None + + try: + return self.report_spec_service.build_report_spec( + report_output=report_output, + simulation_1=simulation_1, + simulation_2=simulation_2, + ) + except ValueError: + return None + + def _find_existing_report_output_for_create( + self, + *, + country_id: str, + simulation_1_id: int, + simulation_2_id: int | None, + year: str, + report_spec: ReportSpec | None = None, + queryer=None, + ) -> dict | None: + queryer = queryer or database + if report_spec is not None: + self._validate_explicit_report_spec_for_create( country_id=country_id, simulation_1_id=simulation_1_id, simulation_2_id=simulation_2_id, year=year, + report_spec=report_spec, + queryer=queryer, + ) + + identity_report_spec = report_spec or self._build_report_spec_for_create( + country_id=country_id, + simulation_1_id=simulation_1_id, + 
simulation_2_id=simulation_2_id, + year=year, + queryer=queryer, + ) + report_identity_hash, report_identity_schema_version = ( + self.report_spec_service.get_report_identity(identity_report_spec) + ) + existing_report = self._find_existing_report_output_row_by_identity( + country_id=country_id, + report_identity_hash=report_identity_hash, + report_identity_schema_version=report_identity_schema_version, + queryer=queryer, + ) + if existing_report is not None: + return existing_report + + candidate_rows = self._list_report_output_rows_by_legacy_key( + country_id=country_id, + simulation_1_id=simulation_1_id, + simulation_2_id=simulation_2_id, + year=year, + queryer=queryer, + ) + for candidate_row in candidate_rows: + candidate_report_spec = self._get_report_spec_for_identity_matching( + candidate_row, + queryer=queryer, + ) + if candidate_report_spec is None: + continue + candidate_identity_hash, candidate_identity_schema_version = ( + self.report_spec_service.get_report_identity(candidate_report_spec) + ) + if ( + candidate_identity_hash == report_identity_hash + and candidate_identity_schema_version == report_identity_schema_version + ): + return candidate_row + + return None + + def _with_requested_report_output_id( + self, report_output_id: int, report_output: dict + ) -> dict: + response_report = dict(report_output) + response_report["id"] = report_output_id + return response_report + + def _merge_display_run_into_report_output( + self, + report_output: dict, + display_run: dict | None, + ) -> dict: + if display_run is None: + return dict(report_output) + + result = dict(report_output) + result["status"] = display_run["status"] + result["output"] = display_run.get("output") + result["error_message"] = display_run.get("error_message") + if display_run.get("report_cache_version") is not None: + result["api_version"] = display_run["report_cache_version"] + for field in ("requested_at", "started_at", "finished_at"): + result[field] = 
self._format_run_timestamp(display_run.get(field)) + return result + + def find_existing_report_output_for_create( + self, + country_id: str, + simulation_1_id: int, + simulation_2_id: int | None = None, + year: str = "2025", + report_spec: ReportSpec | None = None, + ) -> dict | None: + try: + existing_report = self._find_existing_report_output_for_create( + country_id=country_id, + simulation_1_id=simulation_1_id, + simulation_2_id=simulation_2_id, + year=year, + report_spec=report_spec, ) if existing_report is not None: - print(f"Found existing report output with ID: {existing_report['id']}") - return self.ensure_report_output_dual_write_state( - existing_report["id"], - country_id=country_id, + print( + "Found existing report output for create with ID: " + f"{existing_report['id']}" ) - return None - + return existing_report except Exception as e: - print(f"Error checking for existing report output. Details: {str(e)}") + print( + "Error checking for existing report output by canonical identity. " + f"Details: {str(e)}" + ) raise e def create_report_output( @@ -764,6 +1363,8 @@ def create_report_output( simulation_1_id: int, simulation_2_id: int | None = None, year: str = "2025", + report_spec: ReportSpec | None = None, + report_spec_schema_version: int | None = None, ) -> dict: """ Create a new report output record with pending status. 
@@ -774,11 +1375,12 @@ def create_report_output( try: def tx_callback(tx): - existing_report = self._find_existing_report_output_row( + existing_report = self._find_existing_report_output_for_create( country_id=country_id, simulation_1_id=simulation_1_id, simulation_2_id=simulation_2_id, year=year, + report_spec=report_spec, queryer=tx, ) if existing_report is not None: @@ -789,6 +1391,9 @@ def tx_callback(tx): tx, existing_report["id"], country_id=country_id, + explicit_report_spec=report_spec, + report_spec_schema_version=report_spec_schema_version, + ensure_current_report_cache_run=True, ) self._require_simulation_exists( @@ -835,12 +1440,11 @@ def tx_callback(tx): ), ) - created_report = self._find_existing_report_output_row( - country_id=country_id, - simulation_1_id=simulation_1_id, - simulation_2_id=simulation_2_id, - year=year, + created_report_id = self._get_last_inserted_report_output_id(tx) + created_report = self._get_report_output_row( + created_report_id, queryer=tx, + country_id=country_id, ) if created_report is None: raise Exception("Failed to retrieve created report output") @@ -850,6 +1454,8 @@ def tx_callback(tx): tx, created_report["id"], country_id=country_id, + explicit_report_spec=report_spec, + report_spec_schema_version=report_spec_schema_version, ) return database.transaction(tx_callback) @@ -870,21 +1476,61 @@ def get_report_output(self, country_id: str, report_output_id: int) -> dict | No f"Invalid report output ID: {report_output_id}. Must be a positive integer." 
) - report_output = self._get_report_output_row( + resolution = self.report_output_id_map_service.resolve_report_output_id( report_output_id, country_id=country_id, ) - if report_output is None: + if resolution is None: return None - if self._is_current_report_output(report_output): - return self.ensure_report_output_dual_write_state( + canonical_report_output_id = resolution["canonical_report_output_id"] + canonical_report_output = self._get_report_output_row( + canonical_report_output_id, + country_id=country_id, + ) + if canonical_report_output is None: + return None + + if resolution["is_legacy_id"]: + display_run = self._get_report_run_row( + resolution["display_report_output_run_id"], + report_output_id=canonical_report_output_id, + ) + if display_run is None: + raise ValueError( + "Legacy ID mapping points to missing display report output " + f"run #{resolution['display_report_output_run_id']}" + ) + else: + display_run = self.report_run_service.select_display_run( + canonical_report_output + ) + if display_run is None or ( + run_matches_report_result(display_run, canonical_report_output) + and self._run_needs_timestamp_sync( + display_run, + canonical_report_output["status"], + ) + ): + canonical_report_output = ( + self.ensure_report_output_dual_write_state( + canonical_report_output_id, + country_id=country_id, + ) + ) + display_run = self.report_run_service.select_display_run( + canonical_report_output + ) + resolved_report_output = self._merge_display_run_into_report_output( + canonical_report_output, + display_run, + ) + if resolution["is_legacy_id"]: + return self._with_requested_report_output_id( report_output_id, - country_id=country_id, + resolved_report_output, ) - - current_report = self._get_or_create_current_report_output(report_output) - return self._alias_report_output(report_output_id, current_report) + return resolved_report_output except Exception as e: print( @@ -892,6 +1538,196 @@ def get_report_output(self, country_id: str, 
report_output_id: int) -> dict | No ) raise e + def get_report_output_for_run( + self, + country_id: str, + report_output_id: int, + report_output_run_id: str, + ) -> dict | None: + """ + Get a report output projected through one explicit run. + + Normal report reads intentionally apply display-run selection. PATCH + responses for an explicit run need the narrower projection so workers + see the run they just updated, even if it is not the report's display + run. + """ + resolution = self.report_output_id_map_service.resolve_report_output_id( + report_output_id, + country_id=country_id, + ) + if resolution is None: + return None + + canonical_report_output_id = resolution["canonical_report_output_id"] + canonical_report_output = self._get_report_output_row( + canonical_report_output_id, + country_id=country_id, + ) + if canonical_report_output is None: + return None + + explicit_run = self._get_report_run_row( + report_output_run_id, + report_output_id=canonical_report_output_id, + ) + if explicit_run is None: + return None + + resolved_report_output = self._merge_display_run_into_report_output( + canonical_report_output, + explicit_run, + ) + if resolution["is_legacy_id"]: + return self._with_requested_report_output_id( + report_output_id, + resolved_report_output, + ) + return resolved_report_output + + def create_report_rerun( + self, + country_id: str, + report_output_id: int, + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict: + """ + Create a new pending run for the canonical report resolved from the + requested report ID. 
+ """ + print(f"Creating report rerun for report output {report_output_id}") + + def tx_callback(tx): + resolution = self.report_output_id_map_service.resolve_report_output_id( + report_output_id, + queryer=tx, + country_id=country_id, + ) + if resolution is None: + raise ValueError(f"Report output #{report_output_id} not found") + + canonical_report_id = resolution["canonical_report_output_id"] + canonical_report = ( + self._ensure_report_output_dual_write_state_in_transaction( + tx, + canonical_report_id, + country_id=country_id, + ) + ) + canonical_report = self._get_report_output_row( + canonical_report_id, + queryer=tx, + country_id=country_id, + for_update=True, + ) + if canonical_report is None: + raise ValueError(f"Report output #{report_output_id} not found") + + simulation_1, simulation_2 = self._get_linked_simulations( + canonical_report, + queryer=tx, + bootstrap_dual_write_state=True, + ) + report_spec = self._upsert_report_spec_in_transaction( + tx, + canonical_report, + simulation_1, + simulation_2, + ) + existing_runs_descending = self._list_report_runs_descending( + canonical_report_id, + queryer=tx, + ) + source_report_run = select_display_report_run( + canonical_report, + existing_runs_descending, + ) + report_version_manifest = ( + self._build_existing_run_version_manifest( + source_report_run, + canonical_report, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) + if source_report_run is not None + else self._build_bootstrap_version_manifest( + canonical_report, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) + ) + report_run = ( + self.report_run_service.create_report_output_run_in_transaction( + tx, + canonical_report_id, + status="pending", + trigger_type="rerun", + source_run_id=( + source_report_run["id"] + if source_report_run is not None + else 
None + ), + report_spec_snapshot=( + report_spec.model_dump() if report_spec is not None else None + ), + version_manifest=report_version_manifest, + ) + ) + report_run_id = report_run["id"] + + simulation_run_ids: list[str] = [] + for input_position, simulation in ( + (1, simulation_1), + (2, simulation_2), + ): + if simulation is None: + continue + + simulation_run = self.simulation_service.create_report_rerun_simulation_run_in_transaction( + tx, + simulation, + report_output_run_id=report_run_id, + input_position=input_position, + ) + simulation_run_ids.append(simulation_run["id"]) + + canonical_report["status"] = "pending" + canonical_report["output"] = None + canonical_report["error_message"] = None + report_runs_descending = self._list_report_runs_descending( + canonical_report_id, + queryer=tx, + ) + refreshed_report = self._sync_parent_mirror_from_display_run_in_transaction( + tx, + canonical_report, + report_runs_descending, + ) + selected_report = self._merge_display_run_into_report_output( + refreshed_report, + self._get_report_run_row(report_run_id, queryer=tx), + ) + return { + "requested_report_output_id": report_output_id, + "report_output_id": canonical_report_id, + "report_output_run_id": report_run_id, + "simulation_run_ids": simulation_run_ids, + "report_spec": ( + report_spec.model_dump() if report_spec is not None else None + ), + "report": selected_report, + } + + try: + return database.transaction(tx_callback) + except Exception as e: + print(f"Error creating report rerun #{report_output_id}. Details: {str(e)}") + raise e + def update_report_output( self, country_id: str, @@ -899,6 +1735,8 @@ def update_report_output( status: str | None = None, output: str | None = None, error_message: str | None = None, + report_output_run_id: str | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> bool: """ Update a report output record with results or error. 
@@ -906,51 +1744,165 @@ def update_report_output( print(f"Updating report output {report_id}") try: - update_fields = [] - update_values = [] - - if status is not None: - update_fields.append("status = ?") - update_values.append(status) - - if output is not None: - update_fields.append("output = ?") - update_values.append(output) - - if error_message is not None: - update_fields.append("error_message = ?") - update_values.append(error_message) - - if not update_fields: + has_user_fields = ( + status is not None or output is not None or error_message is not None + ) + if not has_user_fields and not version_manifest_overrides: print("No fields to update") return False def tx_callback(tx): - requested_report = self._get_report_output_row( + resolution = self.report_output_id_map_service.resolve_report_output_id( report_id, queryer=tx, country_id=country_id, + ) + if resolution is None: + raise ValueError(f"Report output #{report_id} not found") + + canonical_report_id = resolution["canonical_report_output_id"] + canonical_report = self._get_report_output_row( + canonical_report_id, + queryer=tx, + country_id=country_id, for_update=True, ) - if requested_report is None: + if canonical_report is None: raise ValueError(f"Report output #{report_id} not found") - if status == "running" and not self._has_mutable_running_run( - requested_report, queryer=tx + try: + simulation_1, simulation_2 = self._get_linked_simulations( + canonical_report, + queryer=tx, + bootstrap_dual_write_state=True, + ) + except ValueError as exc: + print( + "Skipping linked simulation sync for report output " + f"#{canonical_report_id}. 
Details: {str(exc)}" + ) + simulation_1, simulation_2 = None, None + + report_spec = self._upsert_report_spec_in_transaction( + tx, + canonical_report, + simulation_1, + simulation_2, + ) + self._sync_report_identity_in_transaction( + tx, + canonical_report, + report_spec, + ) + + runs_descending = self._list_report_runs_descending( + canonical_report_id, + queryer=tx, + ) + if not runs_descending: + version_manifest = self._build_bootstrap_version_manifest( + canonical_report, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) + self._insert_bootstrap_report_run( + tx, + canonical_report, + report_spec, + version_manifest, + ) + runs_descending = self._list_report_runs_descending( + canonical_report_id, + queryer=tx, + ) + + if report_output_run_id is not None: + mutable_run = self._get_report_run_row( + report_output_run_id, + queryer=tx, + report_output_id=canonical_report_id, + for_update=True, + ) + if mutable_run is None: + raise ValueError( + "Report output run " + f"#{report_output_run_id} not found for report " + f"#{canonical_report_id}" + ) + elif resolution["is_legacy_id"]: + mutable_run = self._get_report_run_row( + resolution["display_report_output_run_id"], + queryer=tx, + report_output_id=canonical_report_id, + for_update=True, + ) + if mutable_run is None: + raise ValueError( + "Legacy ID mapping points to missing display report " + "output run " + f"#{resolution['display_report_output_run_id']}" + ) + else: + mutable_run = self._select_mutable_run( + canonical_report, + runs_descending, + ) + + if mutable_run is None: + raise ValueError( + "Cannot update report output without an active report run" + ) + + if status == "running" and mutable_run["status"] not in ( + "pending", + "running", ): raise ValueError( "Cannot mark report output running without an active " "pending or running report run" ) - tx.query( - f"UPDATE report_outputs SET {', 
'.join(update_fields)} WHERE id = ? AND country_id = ?", - (*update_values, report_id, country_id), + run_update_state = dict(canonical_report) + run_update_state["status"] = ( + status if status is not None else mutable_run["status"] ) - self._ensure_report_output_dual_write_state_in_transaction( + run_update_state["output"] = ( + output if output is not None else mutable_run.get("output") + ) + run_update_state["error_message"] = ( + error_message + if error_message is not None + else mutable_run.get("error_message") + ) + + version_manifest = self._build_existing_run_version_manifest( + mutable_run, + canonical_report, + report_spec=report_spec, + simulation_1=simulation_1, + simulation_2=simulation_2, + version_manifest_overrides=version_manifest_overrides, + ) + self._update_report_run_in_transaction( tx, - report_id, - country_id=country_id, + run_id=mutable_run["id"], + report_output=run_update_state, + report_spec=report_spec, + version_manifest=version_manifest, + ) + canonical_report["status"] = run_update_state["status"] + canonical_report["output"] = run_update_state["output"] + canonical_report["error_message"] = run_update_state["error_message"] + runs_descending = self._list_report_runs_descending( + canonical_report_id, + queryer=tx, + ) + self._sync_parent_mirror_from_display_run_in_transaction( + tx, + canonical_report, + runs_descending, ) database.transaction(tx_callback) diff --git a/policyengine_api/services/report_run_service.py b/policyengine_api/services/report_run_service.py index 9899f6cc9..74fc61e96 100644 --- a/policyengine_api/services/report_run_service.py +++ b/policyengine_api/services/report_run_service.py @@ -55,69 +55,98 @@ def create_report_output_run( report_spec_snapshot: dict[str, Any] | str | None = None, version_manifest: dict[str, str | None] | None = None, run_id: str | None = None, + ) -> dict: + def create_run_transaction(tx) -> dict: + return self.create_report_output_run_in_transaction( + tx, + report_output_id, + 
status=status, + trigger_type=trigger_type, + output=output, + error_message=error_message, + source_run_id=source_run_id, + report_spec_snapshot=report_spec_snapshot, + version_manifest=version_manifest, + run_id=run_id, + ) + + return database.transaction(create_run_transaction) + + def create_report_output_run_in_transaction( + self, + tx, + report_output_id: int, + status: str = "pending", + trigger_type: str = "initial", + output: dict[str, Any] | list[Any] | str | None = None, + error_message: str | None = None, + source_run_id: str | None = None, + report_spec_snapshot: dict[str, Any] | str | None = None, + version_manifest: dict[str, str | None] | None = None, + run_id: str | None = None, ) -> dict: run_id = run_id or str(uuid.uuid4()) version_manifest = version_manifest or {} lock_clause = "" if database.local else " FOR UPDATE" - def create_run_transaction(tx) -> None: - parent_row: Row | None = tx.query( - f"SELECT id FROM report_outputs WHERE id = ?{lock_clause}", - (report_output_id,), - ).fetchone() - if parent_row is None: - raise ValueError(f"Report output #{report_output_id} not found") - - run_sequence_row: Row | None = tx.query( - """ - SELECT COALESCE(MAX(run_sequence), 0) AS max_run_sequence - FROM report_output_runs - WHERE report_output_id = ? 
- """, - (report_output_id,), - ).fetchone() - run_sequence = ( - int(run_sequence_row["max_run_sequence"]) + 1 - if run_sequence_row is not None - else 1 - ) + parent_row: Row | None = tx.query( + f"SELECT id FROM report_outputs WHERE id = ?{lock_clause}", + (report_output_id,), + ).fetchone() + if parent_row is None: + raise ValueError(f"Report output #{report_output_id} not found") - requested_at = self._utc_timestamp() - is_terminal = status in ("complete", "error") - has_started = status in ("running", "complete", "error") - started_at = requested_at if has_started else None - finished_at = requested_at if is_terminal else None - - tx.query( - f""" - INSERT INTO report_output_runs ( - id, report_output_id, run_sequence, status, output, error_message, - trigger_type, requested_at, started_at, finished_at, source_run_id, - report_spec_snapshot_json, {", ".join(REPORT_RUN_VERSION_FIELDS)} - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - run_id, - report_output_id, - run_sequence, - status, - self._serialize_json(output), - error_message, - trigger_type, - requested_at, - started_at, - finished_at, - source_run_id, - self._serialize_json(report_spec_snapshot), - *[ - version_manifest.get(field) - for field in REPORT_RUN_VERSION_FIELDS - ], - ), - ) + run_sequence_row: Row | None = tx.query( + """ + SELECT COALESCE(MAX(run_sequence), 0) AS max_run_sequence + FROM report_output_runs + WHERE report_output_id = ? 
+ """, + (report_output_id,), + ).fetchone() + run_sequence = ( + int(run_sequence_row["max_run_sequence"]) + 1 + if run_sequence_row is not None + else 1 + ) - database.transaction(create_run_transaction) - return self.get_report_output_run(run_id) + requested_at = self._utc_timestamp() + is_terminal = status in ("complete", "error") + has_started = status in ("running", "complete", "error") + started_at = requested_at if has_started else None + finished_at = requested_at if is_terminal else None + + tx.query( + f""" + INSERT INTO report_output_runs ( + id, report_output_id, run_sequence, status, output, error_message, + trigger_type, requested_at, started_at, finished_at, source_run_id, + report_spec_snapshot_json, {", ".join(REPORT_RUN_VERSION_FIELDS)} + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + run_id, + report_output_id, + run_sequence, + status, + self._serialize_json(output), + error_message, + trigger_type, + requested_at, + started_at, + finished_at, + source_run_id, + self._serialize_json(report_spec_snapshot), + *[version_manifest.get(field) for field in REPORT_RUN_VERSION_FIELDS], + ), + ) + created_row: Row | None = tx.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (run_id,), + ).fetchone() + if created_row is None: + raise ValueError(f"Report output run #{run_id} not found after create") + return self._parse_run_row(created_row) def get_report_output_run(self, run_id: str) -> dict | None: row: Row | None = database.query( diff --git a/policyengine_api/services/report_spec_service.py b/policyengine_api/services/report_spec_service.py index b81cc566f..3a3134232 100644 --- a/policyengine_api/services/report_spec_service.py +++ b/policyengine_api/services/report_spec_service.py @@ -1,12 +1,15 @@ import json +import hashlib from typing import Any, Literal from pydantic import BaseModel, Field from sqlalchemy.engine.row import Row from policyengine_api.data import database +from 
policyengine_api.data.congressional_districts import normalize_us_region REPORT_SPEC_SCHEMA_VERSION = 1 +REPORT_IDENTITY_SCHEMA_VERSION = 1 REPORT_SPEC_STATUSES = {"explicit", "backfilled_assumed"} HOUSEHOLD_REPORT_KINDS = {"household_single", "household_comparison"} ECONOMY_REPORT_KINDS = {"economy_single", "economy_comparison"} @@ -48,6 +51,14 @@ def _validate_schema_version(self, schema_version: int | None) -> None: f"Unsupported report spec schema version: {schema_version}" ) + def _validate_report_identity_schema_version( + self, schema_version: int | None + ) -> None: + if schema_version != REPORT_IDENTITY_SCHEMA_VERSION: + raise ValueError( + f"Unsupported report identity schema version: {schema_version}" + ) + def _get_report_output_row(self, report_output_id: int) -> dict | None: row: Row | None = database.query( "SELECT * FROM report_outputs WHERE id = ?", @@ -211,6 +222,20 @@ def _validate_report_spec_matches_row( self, report_output: dict, report_spec: ReportSpec ) -> None: simulation_1, simulation_2 = self._get_linked_simulations(report_output) + self.validate_report_spec_matches_context( + report_output, + report_spec, + simulation_1, + simulation_2, + ) + + def validate_report_spec_matches_context( + self, + report_output: dict, + report_spec: ReportSpec, + simulation_1: dict, + simulation_2: dict | None = None, + ) -> None: inferred_report_kind = self.infer_report_kind(simulation_1, simulation_2) if report_spec.country_id != report_output["country_id"]: raise ValueError("Report spec country must match report output country") @@ -268,6 +293,17 @@ def _validate_report_spec_matches_row( "Report spec reform_policy_id must match linked simulations" ) + def parse_report_spec( + self, + raw_spec: dict, + schema_version: int = REPORT_SPEC_SCHEMA_VERSION, + ) -> ReportSpec: + self._validate_schema_version(schema_version) + report_kind = raw_spec.get("report_kind") + if report_kind is None: + raise ValueError("Report spec is missing report_kind") + return 
self._parse_report_spec(report_kind, raw_spec) + def infer_report_kind( self, simulation_1: dict, @@ -339,6 +375,92 @@ def _parse_json_field(self, value: str | dict | None) -> dict | None: return json.loads(value) return value + def canonicalize_report_spec_for_identity( + self, + report_spec: ReportSpec, + schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, + ) -> dict[str, Any]: + return self.build_report_identity_document( + report_spec, + schema_version=schema_version, + ) + + def build_report_identity_document( + self, + report_spec: ReportSpec, + schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, + ) -> dict[str, Any]: + self._validate_report_identity_schema_version(schema_version) + + identity_document: dict[str, Any] = { + "schema_version": schema_version, + "country_id": report_spec.country_id, + "report_kind": report_spec.report_kind, + "time_period": report_spec.time_period, + } + if isinstance(report_spec, HouseholdReportSpec): + identity_document["inputs"] = { + "simulation_1": report_spec.simulation_1.model_dump(), + "simulation_2": ( + report_spec.simulation_2.model_dump() + if report_spec.simulation_2 is not None + else None + ), + } + return identity_document + + region = report_spec.region + if report_spec.country_id == "us": + region = normalize_us_region(region) + identity_document["inputs"] = { + "region": region, + "baseline_policy_id": report_spec.baseline_policy_id, + "reform_policy_id": report_spec.reform_policy_id, + "dataset": report_spec.dataset, + "target": report_spec.target, + "options": report_spec.options, + } + return identity_document + + def serialize_canonical_report_spec_for_identity( + self, + report_spec: ReportSpec, + schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, + ) -> str: + canonical_spec = self.build_report_identity_document( + report_spec, + schema_version=schema_version, + ) + return json.dumps( + canonical_spec, + sort_keys=True, + separators=(",", ":"), + ) + + def get_report_identity_hash( + self, + 
report_spec: ReportSpec, + schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, + ) -> str: + canonical_json = self.serialize_canonical_report_spec_for_identity( + report_spec, + schema_version=schema_version, + ) + return hashlib.sha256(canonical_json.encode("utf-8")).hexdigest() + + def get_report_identity( + self, + report_spec: ReportSpec, + schema_version: int = REPORT_IDENTITY_SCHEMA_VERSION, + ) -> tuple[str, int]: + return ( + self.get_report_identity_hash( + report_spec, + schema_version=schema_version, + ), + schema_version, + ) + def _parse_report_spec(self, report_kind: str, raw_spec: dict) -> ReportSpec: if report_kind in HOUSEHOLD_REPORT_KINDS: return HouseholdReportSpec.model_validate(raw_spec) diff --git a/policyengine_api/services/simulation_run_service.py b/policyengine_api/services/simulation_run_service.py index 544aca9c2..5123e61e2 100644 --- a/policyengine_api/services/simulation_run_service.py +++ b/policyengine_api/services/simulation_run_service.py @@ -48,66 +48,105 @@ def create_simulation_run( simulation_spec_snapshot: dict[str, Any] | str | None = None, version_manifest: dict[str, str | None] | None = None, run_id: str | None = None, + ) -> dict: + def create_run_transaction(tx) -> dict: + return self.create_simulation_run_in_transaction( + tx, + simulation_id, + report_output_run_id=report_output_run_id, + input_position=input_position, + status=status, + trigger_type=trigger_type, + output=output, + error_message=error_message, + source_run_id=source_run_id, + simulation_spec_snapshot=simulation_spec_snapshot, + version_manifest=version_manifest, + run_id=run_id, + ) + + return database.transaction(create_run_transaction) + + def create_simulation_run_in_transaction( + self, + tx, + simulation_id: int, + report_output_run_id: str | None = None, + input_position: int | None = None, + status: str = "pending", + trigger_type: str = "initial", + output: dict[str, Any] | list[Any] | str | None = None, + error_message: str | None = None, + 
source_run_id: str | None = None, + simulation_spec_snapshot: dict[str, Any] | str | None = None, + version_manifest: dict[str, str | None] | None = None, + run_id: str | None = None, + requested_at: str | None = None, + started_at: str | None = None, + finished_at: str | None = None, ) -> dict: run_id = run_id or str(uuid.uuid4()) version_manifest = version_manifest or {} lock_clause = "" if database.local else " FOR UPDATE" - def create_run_transaction(tx) -> None: - parent_row: Row | None = tx.query( - f"SELECT id FROM simulations WHERE id = ?{lock_clause}", - (simulation_id,), - ).fetchone() - if parent_row is None: - raise ValueError(f"Simulation #{simulation_id} not found") - - run_sequence_row: Row | None = tx.query( - """ - SELECT COALESCE(MAX(run_sequence), 0) AS max_run_sequence - FROM simulation_runs - WHERE simulation_id = ? - """, - (simulation_id,), - ).fetchone() - run_sequence = ( - int(run_sequence_row["max_run_sequence"]) + 1 - if run_sequence_row is not None - else 1 - ) - - tx.query( - f""" - INSERT INTO simulation_runs ( - id, simulation_id, report_output_run_id, input_position, run_sequence, - status, output, error_message, trigger_type, requested_at, started_at, - finished_at, source_run_id, simulation_spec_snapshot_json, - {", ".join(SIMULATION_RUN_VERSION_FIELDS)} - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, - ( - run_id, - simulation_id, - report_output_run_id, - input_position, - run_sequence, - status, - self._serialize_json(output), - error_message, - trigger_type, - None, - None, - None, - source_run_id, - self._serialize_json(simulation_spec_snapshot), - *[ - version_manifest.get(field) - for field in SIMULATION_RUN_VERSION_FIELDS - ], - ), - ) + parent_row: Row | None = tx.query( + f"SELECT id FROM simulations WHERE id = ?{lock_clause}", + (simulation_id,), + ).fetchone() + if parent_row is None: + raise ValueError(f"Simulation #{simulation_id} not found") - database.transaction(create_run_transaction) - return self.get_simulation_run(run_id) + run_sequence_row: Row | None = tx.query( + """ + SELECT COALESCE(MAX(run_sequence), 0) AS max_run_sequence + FROM simulation_runs + WHERE simulation_id = ? + """, + (simulation_id,), + ).fetchone() + run_sequence = ( + int(run_sequence_row["max_run_sequence"]) + 1 + if run_sequence_row is not None + else 1 + ) + + tx.query( + f""" + INSERT INTO simulation_runs ( + id, simulation_id, report_output_run_id, input_position, run_sequence, + status, output, error_message, trigger_type, requested_at, started_at, + finished_at, source_run_id, simulation_spec_snapshot_json, + {", ".join(SIMULATION_RUN_VERSION_FIELDS)} + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + run_id, + simulation_id, + report_output_run_id, + input_position, + run_sequence, + status, + self._serialize_json(output), + error_message, + trigger_type, + requested_at, + started_at, + finished_at, + source_run_id, + self._serialize_json(simulation_spec_snapshot), + *[ + version_manifest.get(field) + for field in SIMULATION_RUN_VERSION_FIELDS + ], + ), + ) + created_row: Row | None = tx.query( + "SELECT * FROM simulation_runs WHERE id = ?", + (run_id,), + ).fetchone() + if created_row is None: + raise ValueError(f"Simulation run #{run_id} not found after create") + return self._parse_run_row(created_row) def get_simulation_run(self, run_id: str) -> dict | None: row: Row | None = database.query( diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index e5582ee17..865a6e5e9 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -1,4 +1,5 @@ import uuid +from datetime import datetime, timezone from sqlalchemy.engine.row import Row @@ -9,6 +10,7 @@ parse_json_field, serialize_json_field, ) +from policyengine_api.services.simulation_run_service import SimulationRunService from policyengine_api.services.simulation_spec_service import ( SimulationSpec, SimulationSpecService, @@ -18,10 +20,14 @@ class SimulationService: def __init__(self): self.simulation_spec_service = SimulationSpecService() + self.simulation_run_service = SimulationRunService() def _lock_clause(self) -> str: return "" if database.local else " FOR UPDATE" + def _utc_timestamp(self) -> str: + return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + def _get_simulation_row( self, simulation_id: int, @@ -62,14 +68,51 @@ def _find_existing_simulation_row( ).fetchone() return dict(row) if row is not None else None - def _build_version_manifest(self, simulation: dict) -> dict[str, str | None]: - return { + def _merge_version_manifest_overrides( + self, + 
version_manifest: dict[str, str | None], + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + merged_manifest = dict(version_manifest) + for key, value in (version_manifest_overrides or {}).items(): + if key in merged_manifest and value is not None: + merged_manifest[key] = value + return merged_manifest + + def _build_bootstrap_version_manifest( + self, + simulation: dict, + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + version_manifest = { "country_package_version": simulation.get("api_version"), "policyengine_version": None, "data_version": None, "runtime_app_name": None, "simulation_cache_version": None, } + return self._merge_version_manifest_overrides( + version_manifest, + version_manifest_overrides=version_manifest_overrides, + ) + + def _build_existing_run_version_manifest( + self, + run: dict, + simulation: dict, + version_manifest_overrides: dict[str, str | None] | None = None, + ) -> dict[str, str | None]: + fallback_manifest = self._build_bootstrap_version_manifest(simulation) + version_manifest = { + key: run.get(key) + if run.get(key) is not None + else fallback_manifest.get(key) + for key in fallback_manifest + } + return self._merge_version_manifest_overrides( + version_manifest, + version_manifest_overrides=version_manifest_overrides, + ) def _list_simulation_runs_descending( self, simulation_id: int, *, queryer=None @@ -93,6 +136,32 @@ def _list_simulation_runs_descending( runs.append(run) return runs + def _get_simulation_run_row( + self, + run_id: str, + *, + queryer=None, + simulation_id: int | None = None, + for_update: bool = False, + ) -> dict | None: + queryer = queryer or database + query = "SELECT * FROM simulation_runs WHERE id = ?" + params: list[str | int] = [run_id] + if simulation_id is not None: + query += " AND simulation_id = ?" 
+ params.append(simulation_id) + if for_update: + query += self._lock_clause() + + row: Row | None = queryer.query(query, tuple(params)).fetchone() + if row is None: + return None + run = dict(row) + run["simulation_spec_snapshot_json"] = parse_json_field( + run.get("simulation_spec_snapshot_json") + ) + return run + def _select_mutable_run( self, simulation: dict, runs_descending: list[dict] ) -> dict | None: @@ -103,6 +172,23 @@ def _select_mutable_run( return run return runs_descending[0] if runs_descending else None + def _select_display_run( + self, simulation: dict, runs_descending: list[dict] + ) -> dict | None: + active_run_id = simulation.get("active_run_id") + if active_run_id is not None: + for run in runs_descending: + if run["id"] == active_run_id: + return run + + latest_successful_run_id = simulation.get("latest_successful_run_id") + if latest_successful_run_id is not None: + for run in runs_descending: + if run["id"] == latest_successful_run_id: + return run + + return runs_descending[0] if runs_descending else None + def _upsert_simulation_spec_in_transaction( self, tx, simulation: dict ) -> SimulationSpec: @@ -134,8 +220,8 @@ def _run_matches_parent( run: dict, simulation: dict, simulation_spec: SimulationSpec, + version_manifest: dict[str, str | None], ) -> bool: - version_manifest = self._build_version_manifest(simulation) return ( run["status"] == simulation["status"] and run.get("output") == simulation.get("output") @@ -152,9 +238,12 @@ def _run_matches_parent( ) def _insert_bootstrap_run( - self, tx, simulation: dict, simulation_spec: SimulationSpec + self, + tx, + simulation: dict, + simulation_spec: SimulationSpec, + version_manifest: dict[str, str | None], ) -> None: - version_manifest = self._build_version_manifest(simulation) tx.query( """ INSERT INTO simulation_runs ( @@ -194,8 +283,8 @@ def _update_simulation_run_in_transaction( run_id: str, simulation: dict, simulation_spec: SimulationSpec, + version_manifest: dict[str, str | None], ) 
-> None: - version_manifest = self._build_version_manifest(simulation) tx.query( """ UPDATE simulation_runs @@ -247,12 +336,135 @@ def _sync_parent_pointers_in_transaction( simulation["active_run_id"] = desired_active_run_id simulation["latest_successful_run_id"] = desired_latest_successful_run_id + def _sync_parent_mirror_from_display_run_in_transaction( + self, + tx, + simulation: dict, + runs_descending: list[dict], + ) -> dict: + self._sync_parent_pointers_in_transaction(tx, simulation, runs_descending) + display_run = self._select_display_run(simulation, runs_descending) + if display_run is None: + refreshed_simulation = self._get_simulation_row( + simulation["id"], + queryer=tx, + country_id=simulation["country_id"], + ) + if refreshed_simulation is None: + raise ValueError(f"Simulation #{simulation['id']} not found after sync") + return refreshed_simulation + + parent_api_version = ( + display_run["simulation_cache_version"] + if display_run.get("simulation_cache_version") is not None + else simulation.get("api_version") + ) + tx.query( + """ + UPDATE simulations + SET status = ?, output = ?, error_message = ?, api_version = ? + WHERE id = ? AND country_id = ? 
+ """, + ( + display_run["status"], + serialize_json_field(display_run.get("output")), + display_run.get("error_message"), + parent_api_version, + simulation["id"], + simulation["country_id"], + ), + ) + refreshed_simulation = self._get_simulation_row( + simulation["id"], + queryer=tx, + country_id=simulation["country_id"], + ) + if refreshed_simulation is None: + raise ValueError(f"Simulation #{simulation['id']} not found after sync") + return refreshed_simulation + + def create_report_rerun_simulation_run_in_transaction( + self, + tx, + simulation: dict, + *, + report_output_run_id: str, + input_position: int, + ) -> dict: + simulation_spec = self._upsert_simulation_spec_in_transaction(tx, simulation) + runs_descending = self._list_simulation_runs_descending( + simulation["id"], + queryer=tx, + ) + source_run = self._select_display_run(simulation, runs_descending) + version_manifest = ( + self._build_existing_run_version_manifest( + source_run, + simulation, + ) + if source_run is not None + else self._build_bootstrap_version_manifest(simulation) + ) + created_run = self.simulation_run_service.create_simulation_run_in_transaction( + tx, + simulation["id"], + report_output_run_id=report_output_run_id, + input_position=input_position, + status="pending", + trigger_type="report_rerun", + source_run_id=source_run["id"] if source_run is not None else None, + simulation_spec_snapshot=simulation_spec.model_dump(), + version_manifest=version_manifest, + requested_at=self._utc_timestamp(), + ) + + simulation["status"] = "pending" + simulation["output"] = None + simulation["error_message"] = None + runs_descending = self._list_simulation_runs_descending( + simulation["id"], + queryer=tx, + ) + self._sync_parent_pointers_in_transaction(tx, simulation, runs_descending) + tx.query( + """ + UPDATE simulations + SET status = ?, output = ?, error_message = ? + WHERE id = ? AND country_id = ? 
+ """, + ( + "pending", + None, + None, + simulation["id"], + simulation["country_id"], + ), + ) + return created_run + + def _merge_display_run_into_simulation( + self, + simulation: dict, + display_run: dict | None, + ) -> dict: + if display_run is None: + return dict(simulation) + + result = dict(simulation) + result["status"] = display_run["status"] + result["output"] = display_run.get("output") + result["error_message"] = display_run.get("error_message") + if display_run.get("simulation_cache_version") is not None: + result["api_version"] = display_run["simulation_cache_version"] + return result + def _ensure_simulation_dual_write_state_in_transaction( self, tx, simulation_id: int, *, country_id: str | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict: simulation = self._get_simulation_row( simulation_id, @@ -268,22 +480,45 @@ def _ensure_simulation_dual_write_state_in_transaction( simulation_id, queryer=tx ) if not runs_descending: - self._insert_bootstrap_run(tx, simulation, simulation_spec) + version_manifest = self._build_bootstrap_version_manifest( + simulation, + version_manifest_overrides=version_manifest_overrides, + ) + self._insert_bootstrap_run( + tx, + simulation, + simulation_spec, + version_manifest=version_manifest, + ) runs_descending = self._list_simulation_runs_descending( simulation_id, queryer=tx ) else: mutable_run = self._select_mutable_run(simulation, runs_descending) + version_manifest = ( + self._build_existing_run_version_manifest( + mutable_run, + simulation, + version_manifest_overrides=version_manifest_overrides, + ) + if mutable_run is not None + else self._build_bootstrap_version_manifest( + simulation, + version_manifest_overrides=version_manifest_overrides, + ) + ) if mutable_run is not None and not self._run_matches_parent( mutable_run, simulation, simulation_spec, + version_manifest=version_manifest, ): self._update_simulation_run_in_transaction( tx, run_id=mutable_run["id"], 
simulation=simulation, simulation_spec=simulation_spec, + version_manifest=version_manifest, ) runs_descending = self._list_simulation_runs_descending( simulation_id, queryer=tx @@ -300,13 +535,17 @@ def _ensure_simulation_dual_write_state_in_transaction( return refreshed_simulation def ensure_simulation_dual_write_state( - self, simulation_id: int, country_id: str | None = None + self, + simulation_id: int, + country_id: str | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> dict: return database.transaction( lambda tx: self._ensure_simulation_dual_write_state_in_transaction( tx, simulation_id, country_id=country_id, + version_manifest_overrides=version_manifest_overrides, ) ) @@ -446,12 +685,45 @@ def get_simulation(self, country_id: str, simulation_id: int) -> dict | None: f"Invalid simulation ID: {simulation_id}. Must be a positive integer." ) - return self._get_simulation_row(simulation_id, country_id=country_id) + simulation = self._get_simulation_row(simulation_id, country_id=country_id) + if simulation is None: + return None + + runs_descending = self._list_simulation_runs_descending(simulation_id) + display_run = self._select_display_run(simulation, runs_descending) + return self._merge_display_run_into_simulation(simulation, display_run) except Exception as e: print(f"Error fetching simulation #{simulation_id}. Details: {str(e)}") raise e + def get_simulation_for_run( + self, + country_id: str, + simulation_id: int, + simulation_run_id: str, + ) -> dict | None: + """ + Get a simulation projected through one explicit run. + + Normal simulation reads intentionally apply display-run selection. PATCH + responses for an explicit run need the narrower projection so workers + see the run they just updated, even if it is not the simulation's + display run. 
+ """ + simulation = self._get_simulation_row(simulation_id, country_id=country_id) + if simulation is None: + return None + + explicit_run = self._get_simulation_run_row( + simulation_run_id, + simulation_id=simulation_id, + ) + if explicit_run is None: + return None + + return self._merge_display_run_into_simulation(simulation, explicit_run) + def update_simulation( self, country_id: str, @@ -459,6 +731,8 @@ def update_simulation( status: str | None = None, output: str | None = None, error_message: str | None = None, + simulation_run_id: str | None = None, + version_manifest_overrides: dict[str, str | None] | None = None, ) -> bool: """ Update a simulation record with results or error. @@ -474,36 +748,15 @@ def update_simulation( bool: True if update was successful. """ print(f"Updating simulation {simulation_id}") - api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id) try: - update_fields = [] - update_values = [] - - if status is not None: - update_fields.append("status = ?") - update_values.append(status) - - if output is not None: - update_fields.append("output = ?") - update_values.append(output) - - if error_message is not None: - update_fields.append("error_message = ?") - update_values.append(error_message) - - # Only refresh api_version when the caller is actually - # changing one of the user-supplied fields above. The - # previous code appended api_version unconditionally, so - # the "no fields to update" guard below never fired and a - # PATCH with an empty body still touched the row. 
- if not update_fields: + has_user_fields = ( + status is not None or output is not None or error_message is not None + ) + if not has_user_fields and not version_manifest_overrides: print("No fields to update") return False - update_fields.append("api_version = ?") - update_values.append(api_version) - def tx_callback(tx): simulation = self._get_simulation_row( simulation_id, @@ -514,14 +767,86 @@ def tx_callback(tx): if simulation is None: raise ValueError(f"Simulation #{simulation_id} not found") - tx.query( - f"UPDATE simulations SET {', '.join(update_fields)} WHERE id = ? AND country_id = ?", - (*update_values, simulation_id, country_id), + simulation_spec = self._upsert_simulation_spec_in_transaction( + tx, + simulation, + ) + runs_descending = self._list_simulation_runs_descending( + simulation_id, + queryer=tx, + ) + if not runs_descending: + version_manifest = self._build_bootstrap_version_manifest( + simulation, + version_manifest_overrides=version_manifest_overrides, + ) + self._insert_bootstrap_run( + tx, + simulation, + simulation_spec, + version_manifest=version_manifest, + ) + runs_descending = self._list_simulation_runs_descending( + simulation_id, + queryer=tx, + ) + + if simulation_run_id is not None: + mutable_run = self._get_simulation_run_row( + simulation_run_id, + queryer=tx, + simulation_id=simulation_id, + for_update=True, + ) + if mutable_run is None: + raise ValueError( + f"Simulation run #{simulation_run_id} not found for " + f"simulation #{simulation_id}" + ) + else: + mutable_run = self._select_mutable_run(simulation, runs_descending) + + if mutable_run is None: + raise ValueError( + "Cannot update simulation without an active simulation run" + ) + + run_update_state = dict(simulation) + run_update_state["status"] = ( + status if status is not None else mutable_run["status"] + ) + run_update_state["output"] = ( + output if output is not None else mutable_run.get("output") + ) + run_update_state["error_message"] = ( + error_message + if 
error_message is not None + else mutable_run.get("error_message") + ) + + version_manifest = self._build_existing_run_version_manifest( + mutable_run, + simulation, + version_manifest_overrides=version_manifest_overrides, ) - self._ensure_simulation_dual_write_state_in_transaction( + self._update_simulation_run_in_transaction( tx, + run_id=mutable_run["id"], + simulation=run_update_state, + simulation_spec=simulation_spec, + version_manifest=version_manifest, + ) + simulation["status"] = run_update_state["status"] + simulation["output"] = run_update_state["output"] + simulation["error_message"] = run_update_state["error_message"] + runs_descending = self._list_simulation_runs_descending( simulation_id, - country_id=country_id, + queryer=tx, + ) + self._sync_parent_mirror_from_display_run_in_transaction( + tx, + simulation, + runs_descending, ) database.transaction(tx_callback) diff --git a/pyproject.toml b/pyproject.toml index f5063c0c3..8004ad179 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ classifiers = [ "License :: OSI Approved :: GNU Affero General Public License v3", ] dependencies = [ + "alembic>=1.13.0", "anthropic", "assertpy", "click>=8,<9", diff --git a/tests/unit/data/test_run_schema.py b/tests/unit/data/test_run_schema.py index 2bcba1eff..6da72f5de 100644 --- a/tests/unit/data/test_run_schema.py +++ b/tests/unit/data/test_run_schema.py @@ -8,6 +8,16 @@ def _column_names(test_db, table_name: str) -> set[str]: return {row["name"] for row in rows} +def _index_is_unique(test_db, table_name: str, index_name: str) -> bool: + rows = test_db.query(f"PRAGMA index_list({table_name})").fetchall() + return any(row["name"] == index_name and row["unique"] == 1 for row in rows) + + +def _index_exists(test_db, table_name: str, index_name: str) -> bool: + rows = test_db.query(f"PRAGMA index_list({table_name})").fetchall() + return any(row["name"] == index_name for row in rows) + + def 
test_stage_one_run_schema_is_initialized_in_local_test_db(test_db): report_output_columns = _column_names(test_db, "report_outputs") assert { @@ -15,6 +25,8 @@ def test_stage_one_run_schema_is_initialized_in_local_test_db(test_db): "report_spec_json", "report_spec_schema_version", "report_spec_status", + "report_identity_hash", + "report_identity_schema_version", "active_run_id", "latest_successful_run_id", }.issubset(report_output_columns) @@ -61,8 +73,24 @@ def test_stage_one_run_schema_is_initialized_in_local_test_db(test_db): "simulation_cache_version", }.issubset(simulation_run_columns) - alias_columns = _column_names(test_db, "legacy_report_output_aliases") - assert {"legacy_report_output_id", "canonical_report_output_id"} == alias_columns + id_map_columns = _column_names(test_db, "legacy_report_output_id_map") + assert { + "legacy_report_output_id", + "canonical_report_output_id", + "display_report_output_run_id", + } == id_map_columns + + legacy_alias_columns = _column_names(test_db, "legacy_report_output_aliases") + assert { + "legacy_report_output_id", + "canonical_report_output_id", + } == legacy_alias_columns + assert _index_exists(test_db, "report_outputs", "report_outputs_identity_idx") + assert not _index_is_unique( + test_db, + "report_outputs", + "report_outputs_identity_idx", + ) def test_stage_one_schema_is_defined_in_both_sql_initializers(): @@ -75,8 +103,13 @@ def test_stage_one_schema_is_defined_in_both_sql_initializers(): "CREATE TABLE IF NOT EXISTS report_output_runs", "CREATE TABLE IF NOT EXISTS simulation_runs", "CREATE TABLE IF NOT EXISTS legacy_report_output_aliases", + "CREATE TABLE IF NOT EXISTS legacy_report_output_id_map", + "display_report_output_run_id", + "CREATE INDEX report_outputs_identity_idx", "report_spec_json", "report_spec_status", + "report_identity_hash", + "report_identity_schema_version", "simulation_spec_json", "active_run_id", "latest_successful_run_id", diff --git a/tests/unit/routes/test_route_exception_handling.py 
b/tests/unit/routes/test_route_exception_handling.py index b6fb38a28..124c8731b 100644 --- a/tests/unit/routes/test_route_exception_handling.py +++ b/tests/unit/routes/test_route_exception_handling.py @@ -63,7 +63,7 @@ def test_simulation_create_value_error_still_400(): def test_report_create_runtime_error_becomes_500(): client = _client_with(report_output_bp) with patch( - "policyengine_api.routes.report_output_routes.report_output_service.find_existing_report_output", + "policyengine_api.routes.report_output_routes.report_output_service.find_existing_report_output_for_create", side_effect=RuntimeError("db went away"), ): response = client.post( @@ -76,7 +76,7 @@ def test_report_create_runtime_error_becomes_500(): def test_report_create_value_error_still_400(): client = _client_with(report_output_bp) with patch( - "policyengine_api.routes.report_output_routes.report_output_service.find_existing_report_output", + "policyengine_api.routes.report_output_routes.report_output_service.find_existing_report_output_for_create", side_effect=ValueError("bad input"), ): response = client.post( diff --git a/tests/unit/services/test_report_output_alias_service.py b/tests/unit/services/test_report_output_alias_service.py deleted file mode 100644 index e4e28c916..000000000 --- a/tests/unit/services/test_report_output_alias_service.py +++ /dev/null @@ -1,279 +0,0 @@ -import pytest - -from policyengine_api.services.report_output_alias_service import ( - ReportOutputAliasService, -) -from policyengine_api.services.report_output_service import ReportOutputService -from policyengine_api.services.simulation_service import SimulationService - -alias_service = ReportOutputAliasService() -report_output_service = ReportOutputService() -simulation_service = SimulationService() - - -class TestReportOutputAliasService: - def _insert_legacy_report_output( - self, - test_db, - legacy_report_output_id: int, - canonical_report: dict, - api_version: str = "legacy-version", - ) -> None: - 
test_db.query( - """ - INSERT INTO report_outputs ( - id, country_id, simulation_1_id, simulation_2_id, api_version, status, year - ) VALUES (?, ?, ?, ?, ?, ?, ?) - """, - ( - legacy_report_output_id, - canonical_report["country_id"], - canonical_report["simulation_1_id"], - canonical_report["simulation_2_id"], - api_version, - canonical_report["status"], - canonical_report["year"], - ), - ) - - def test_resolves_to_canonical_report_output_id_when_alias_exists(self, test_db): - simulation = simulation_service.create_simulation( - country_id="us", - population_id="household_1", - population_type="household", - policy_id=1, - ) - canonical_report = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2025", - ) - self._insert_legacy_report_output(test_db, 999, canonical_report) - - alias_service.set_alias( - legacy_report_output_id=999, - canonical_report_output_id=canonical_report["id"], - ) - - resolved_id = alias_service.resolve_canonical_report_output_id(999) - - assert resolved_id == canonical_report["id"] - - def test_returns_requested_id_when_alias_is_not_needed(self, test_db): - simulation = simulation_service.create_simulation( - country_id="us", - population_id="household_2", - population_type="household", - policy_id=2, - ) - report_output = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2025", - ) - - resolved_id = alias_service.resolve_canonical_report_output_id( - report_output["id"] - ) - - assert resolved_id == report_output["id"] - - def test_returns_none_for_unknown_report_output(self, test_db): - assert alias_service.resolve_canonical_report_output_id(123456) is None - - def test_set_alias_is_idempotent_for_same_canonical_report_output(self, test_db): - simulation = simulation_service.create_simulation( - country_id="us", - population_id="household_3", - population_type="household", - 
policy_id=3, - ) - canonical_report = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2025", - ) - self._insert_legacy_report_output(test_db, 1001, canonical_report) - - assert ( - alias_service.set_alias( - legacy_report_output_id=1001, - canonical_report_output_id=canonical_report["id"], - ) - is True - ) - assert ( - alias_service.set_alias( - legacy_report_output_id=1001, - canonical_report_output_id=canonical_report["id"], - ) - is True - ) - - def test_rejects_alias_to_missing_canonical_report_output(self, test_db): - simulation = simulation_service.create_simulation( - country_id="us", - population_id="household_3a", - population_type="household", - policy_id=3, - ) - canonical_report = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2025", - ) - self._insert_legacy_report_output(test_db, 1002, canonical_report) - - with pytest.raises(ValueError) as exc_info: - alias_service.set_alias( - legacy_report_output_id=1002, - canonical_report_output_id=999999, - ) - - assert "Canonical report output #999999 not found" in str(exc_info.value) - - def test_rejects_conflicting_alias_remap(self, test_db): - simulation = simulation_service.create_simulation( - country_id="us", - population_id="household_4", - population_type="household", - policy_id=4, - ) - canonical_report = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2025", - ) - other_report = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2026", - ) - self._insert_legacy_report_output(test_db, 1003, canonical_report) - alias_service.set_alias( - legacy_report_output_id=1003, - canonical_report_output_id=canonical_report["id"], - ) - - with pytest.raises(ValueError) as exc_info: - 
alias_service.set_alias( - legacy_report_output_id=1003, - canonical_report_output_id=other_report["id"], - ) - - assert ( - "Legacy report output alias already points to canonical report output " - f"#{canonical_report['id']}" - ) in str(exc_info.value) - - def test_rejects_alias_when_legacy_report_output_is_missing(self, test_db): - simulation = simulation_service.create_simulation( - country_id="us", - population_id="household_4a", - population_type="household", - policy_id=4, - ) - canonical_report = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2025", - ) - - with pytest.raises(ValueError) as exc_info: - alias_service.set_alias( - legacy_report_output_id=10030, - canonical_report_output_id=canonical_report["id"], - ) - - assert "Legacy report output #10030 not found" in str(exc_info.value) - - def test_rejects_alias_when_legacy_and_canonical_reports_do_not_match( - self, test_db - ): - simulation = simulation_service.create_simulation( - country_id="us", - population_id="household_4b", - population_type="household", - policy_id=4, - ) - canonical_report = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2025", - ) - mismatched_report = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2026", - ) - - with pytest.raises(ValueError) as exc_info: - alias_service.set_alias( - legacy_report_output_id=mismatched_report["id"], - canonical_report_output_id=canonical_report["id"], - ) - - assert "must describe the same report" in str(exc_info.value) - - def test_rejects_alias_when_legacy_and_canonical_ids_match(self, test_db): - simulation = simulation_service.create_simulation( - country_id="us", - population_id="household_4c", - population_type="household", - policy_id=4, - ) - canonical_report = 
report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2025", - ) - - with pytest.raises(ValueError) as exc_info: - alias_service.set_alias( - legacy_report_output_id=canonical_report["id"], - canonical_report_output_id=canonical_report["id"], - ) - - assert "must be different" in str(exc_info.value) - - def test_rejects_alias_resolution_when_canonical_report_output_is_missing( - self, test_db - ): - simulation = simulation_service.create_simulation( - country_id="us", - population_id="household_5", - population_type="household", - policy_id=5, - ) - canonical_report = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2025", - ) - self._insert_legacy_report_output(test_db, 1004, canonical_report) - alias_service.set_alias( - legacy_report_output_id=1004, - canonical_report_output_id=canonical_report["id"], - ) - test_db.query( - "DELETE FROM report_outputs WHERE id = ?", - (canonical_report["id"],), - ) - - with pytest.raises(ValueError) as exc_info: - alias_service.resolve_canonical_report_output_id(1004) - - assert ( - f"Alias points to missing canonical report output #{canonical_report['id']}" - ) in str(exc_info.value) diff --git a/tests/unit/services/test_report_output_id_map_service.py b/tests/unit/services/test_report_output_id_map_service.py new file mode 100644 index 000000000..f1dbcc412 --- /dev/null +++ b/tests/unit/services/test_report_output_id_map_service.py @@ -0,0 +1,438 @@ +import pytest + +from policyengine_api.services.report_output_id_map_service import ( + ReportOutputIdMapService, +) +from policyengine_api.services.report_output_service import ReportOutputService +from policyengine_api.services.simulation_service import SimulationService + +id_map_service = ReportOutputIdMapService() +report_output_service = ReportOutputService() +simulation_service = SimulationService() + + +class 
TestReportOutputIdMapService: + def _insert_legacy_report_output( + self, + test_db, + legacy_report_output_id: int, + canonical_report: dict, + api_version: str = "legacy-version", + report_identity_hash: str | None = None, + report_identity_schema_version: int | None = None, + ) -> None: + test_db.query( + """ + INSERT INTO report_outputs ( + id, country_id, simulation_1_id, simulation_2_id, api_version, status, year, + report_identity_hash, report_identity_schema_version + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + legacy_report_output_id, + canonical_report["country_id"], + canonical_report["simulation_1_id"], + canonical_report["simulation_2_id"], + api_version, + canonical_report["status"], + canonical_report["year"], + report_identity_hash or canonical_report.get("report_identity_hash"), + report_identity_schema_version + if report_identity_schema_version is not None + else canonical_report.get("report_identity_schema_version"), + ), + ) + + def _display_run_id(self, test_db, report_output_id: int) -> str: + row = test_db.query( + """ + SELECT id FROM report_output_runs + WHERE report_output_id = ? 
+ ORDER BY run_sequence DESC + LIMIT 1 + """, + (report_output_id,), + ).fetchone() + assert row is not None + return row["id"] + + def test_resolves_to_canonical_report_output_id_when_mapping_exists(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_1", + population_type="household", + policy_id=1, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + self._insert_legacy_report_output(test_db, 999, canonical_report) + + id_map_service.set_mapping( + legacy_report_output_id=999, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), + ) + + resolved_id = id_map_service.resolve_canonical_report_output_id(999) + + assert resolved_id == canonical_report["id"] + + def test_returns_requested_id_when_mapping_is_not_needed(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_2", + population_type="household", + policy_id=2, + ) + report_output = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + + resolved_id = id_map_service.resolve_canonical_report_output_id( + report_output["id"] + ) + + assert resolved_id == report_output["id"] + + def test_returns_none_for_unknown_report_output(self, test_db): + assert id_map_service.resolve_canonical_report_output_id(123456) is None + + def test_set_mapping_is_idempotent_for_same_canonical_report_output(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_3", + population_type="household", + policy_id=3, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + 
year="2025", + ) + self._insert_legacy_report_output(test_db, 1001, canonical_report) + + assert ( + id_map_service.set_mapping( + legacy_report_output_id=1001, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), + ) + is True + ) + assert ( + id_map_service.set_mapping( + legacy_report_output_id=1001, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), + ) + is True + ) + + def test_rejects_mapping_to_missing_canonical_report_output(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_3a", + population_type="household", + policy_id=3, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + self._insert_legacy_report_output(test_db, 1002, canonical_report) + + with pytest.raises(ValueError) as exc_info: + id_map_service.set_mapping( + legacy_report_output_id=1002, + canonical_report_output_id=999999, + display_report_output_run_id="missing-run", + ) + + assert "Canonical report output #999999 not found" in str(exc_info.value) + + def test_rejects_mapping_to_missing_display_report_output_run(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_3b", + population_type="household", + policy_id=3, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + self._insert_legacy_report_output(test_db, 10020, canonical_report) + + with pytest.raises(ValueError) as exc_info: + id_map_service.set_mapping( + legacy_report_output_id=10020, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id="missing-run", + ) + + assert 
"Display report output run #missing-run not found" in str(exc_info.value) + + def test_rejects_conflicting_mapping_remap(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_4", + population_type="household", + policy_id=4, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + other_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + self._insert_legacy_report_output(test_db, 1003, canonical_report) + id_map_service.set_mapping( + legacy_report_output_id=1003, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), + ) + + with pytest.raises(ValueError) as exc_info: + id_map_service.set_mapping( + legacy_report_output_id=1003, + canonical_report_output_id=other_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, other_report["id"] + ), + ) + + assert ( + "Legacy report output ID already maps to canonical report output " + f"#{canonical_report['id']}" + ) in str(exc_info.value) + + def test_allows_mapping_when_legacy_report_output_is_missing(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_4a", + population_type="household", + policy_id=4, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + + assert ( + id_map_service.set_mapping( + legacy_report_output_id=10030, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), + ) + is True + ) + + resolved = id_map_service.resolve_report_output_id(10030) + assert resolved 
== { + "requested_report_output_id": 10030, + "canonical_report_output_id": canonical_report["id"], + "display_report_output_run_id": self._display_run_id( + test_db, canonical_report["id"] + ), + "is_legacy_id": True, + } + + def test_rejects_mapping_when_reports_do_not_share_canonical_identity( + self, test_db + ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/tx", + population_type="geography", + policy_id=34, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/tx", + population_type="geography", + policy_id=35, + ) + default_report_spec = report_output_service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/tx", + "baseline_policy_id": 34, + "reform_policy_id": 35, + "dataset": "default", + "target": "general", + "options": {}, + } + ) + cliff_report_spec = report_output_service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/tx", + "baseline_policy_id": 34, + "reform_policy_id": 35, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=default_report_spec, + report_spec_schema_version=1, + ) + distinct_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=cliff_report_spec, + report_spec_schema_version=1, + ) + + with pytest.raises(ValueError) as exc_info: + id_map_service.set_mapping( + legacy_report_output_id=distinct_report["id"], + canonical_report_output_id=canonical_report["id"], + 
display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), + ) + + assert "must share canonical report identity" in str(exc_info.value) + + def test_rejects_mapping_when_legacy_report_output_has_no_identity(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_4b", + population_type="household", + policy_id=4, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + self._insert_legacy_report_output( + test_db, + legacy_report_output_id=10031, + canonical_report=canonical_report, + report_identity_hash=None, + report_identity_schema_version=None, + ) + test_db.query( + """ + UPDATE report_outputs + SET report_identity_hash = NULL, report_identity_schema_version = NULL + WHERE id = ? + """, + (10031,), + ) + + with pytest.raises(ValueError) as exc_info: + id_map_service.set_mapping( + legacy_report_output_id=10031, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), + ) + + assert "must have canonical report identity" in str(exc_info.value) + + def test_rejects_mapping_when_legacy_and_canonical_ids_match(self, test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_4c", + population_type="household", + policy_id=4, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + + with pytest.raises(ValueError) as exc_info: + id_map_service.set_mapping( + legacy_report_output_id=canonical_report["id"], + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), + ) + + assert "must be different" in str(exc_info.value) + + def 
test_rejects_mapping_resolution_when_canonical_report_output_is_missing( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_5", + population_type="household", + policy_id=5, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + self._insert_legacy_report_output(test_db, 1004, canonical_report) + id_map_service.set_mapping( + legacy_report_output_id=1004, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=self._display_run_id( + test_db, canonical_report["id"] + ), + ) + test_db.query( + "DELETE FROM report_outputs WHERE id = ?", + (canonical_report["id"],), + ) + + with pytest.raises(ValueError) as exc_info: + id_map_service.resolve_canonical_report_output_id(1004) + + assert ( + "Legacy ID mapping points to missing canonical report output " + f"#{canonical_report['id']}" + ) in str(exc_info.value) diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index 55ee2ff62..3de2197e4 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -3,18 +3,20 @@ from datetime import datetime, timezone from policyengine_api.constants import get_report_output_cache_version +from policyengine_api.services.report_output_id_map_service import ( + ReportOutputIdMapService, +) from policyengine_api.services.report_output_service import ReportOutputService from policyengine_api.services.report_run_service import ReportRunService from policyengine_api.services.run_sync_utils import select_display_report_run from policyengine_api.services.simulation_service import SimulationService -from tests.fixtures.services import report_output_fixtures - pytest_plugins = ("tests.fixtures.services.report_output_fixtures",) service = ReportOutputService() 
report_run_service = ReportRunService() simulation_service = SimulationService() +id_map_service = ReportOutputIdMapService() class TestReportOutputRunTimestamps: @@ -76,130 +78,6 @@ def test_select_display_run_uses_matching_result_before_newest_fallback(self): assert selected_run["id"] == "matching" -class TestFindExistingReportOutput: - """Test finding existing report outputs in the database.""" - - def test_find_existing_report_output_found(self, test_db, existing_report_record): - """Test finding an existing report output.""" - # GIVEN an existing report record (from fixture) - - # WHEN we search for a report with matching simulation IDs - result = service.find_existing_report_output( - country_id=existing_report_record["country_id"], - simulation_1_id=existing_report_record["simulation_1_id"], - simulation_2_id=existing_report_record["simulation_2_id"], - ) - - # THEN the result should contain the existing report - assert result is not None - assert result["id"] == existing_report_record["id"] - assert ( - result["country_id"] - == report_output_fixtures.valid_report_data["country_id"] - ) - assert result["simulation_1_id"] == existing_report_record["simulation_1_id"] - assert result["status"] == existing_report_record["status"] - - def test_find_existing_report_output_not_found(self, test_db): - """Test that None is returned when no report exists.""" - # GIVEN an empty database - - # WHEN we search for a non-existent report - result = service.find_existing_report_output( - country_id="us", - simulation_1_id=999, - simulation_2_id=888, - year="2025", - ) - - # THEN None should be returned - assert result is None - - def test_find_existing_report_output_with_null_simulation2(self, test_db): - """Test finding reports where simulation_2_id is NULL.""" - api_version = get_report_output_cache_version("us") - # GIVEN a report with NULL simulation_2_id - test_db.query( - "INSERT INTO report_outputs (country_id, simulation_1_id, simulation_2_id, status, api_version, 
year) VALUES (?, ?, ?, ?, ?, ?)", - ("us", 100, None, "complete", api_version, "2025"), - ) - - # WHEN we search for it - result = service.find_existing_report_output( - country_id="us", - simulation_1_id=100, - simulation_2_id=None, - year="2025", - ) - - # THEN we should find it - assert result is not None - assert result["simulation_1_id"] == 100 - assert result["simulation_2_id"] is None - assert result["year"] == "2025" - - def test_find_existing_report_output_with_year(self, test_db): - """Test finding reports with different years.""" - api_version = get_report_output_cache_version("us") - # GIVEN reports with different years for the same simulation - test_db.query( - "INSERT INTO report_outputs (country_id, simulation_1_id, simulation_2_id, status, api_version, year) VALUES (?, ?, ?, ?, ?, ?)", - ("us", 101, None, "complete", api_version, "2025"), - ) - test_db.query( - "INSERT INTO report_outputs (country_id, simulation_1_id, simulation_2_id, status, api_version, year) VALUES (?, ?, ?, ?, ?, ?)", - ("us", 101, None, "complete", api_version, "2024"), - ) - - # WHEN we search for the 2025 report - result_2025 = service.find_existing_report_output( - country_id="us", - simulation_1_id=101, - simulation_2_id=None, - year="2025", - ) - - # THEN we should find the 2025 report - assert result_2025 is not None - assert result_2025["simulation_1_id"] == 101 - assert result_2025["year"] == "2025" - - # WHEN we search for the 2024 report - result_2024 = service.find_existing_report_output( - country_id="us", - simulation_1_id=101, - simulation_2_id=None, - year="2024", - ) - - # THEN we should find the 2024 report - assert result_2024 is not None - assert result_2024["simulation_1_id"] == 101 - assert result_2024["year"] == "2024" - - # AND the two reports should have different IDs - assert result_2025["id"] != result_2024["id"] - - def test_find_existing_report_output_ignores_stale_runtime_version(self, test_db): - current_version = 
get_report_output_cache_version("us") - stale_version = "r0stale1" - assert stale_version != current_version - - test_db.query( - "INSERT INTO report_outputs (country_id, simulation_1_id, simulation_2_id, status, api_version, year) VALUES (?, ?, ?, ?, ?, ?)", - ("us", 102, None, "complete", stale_version, "2025"), - ) - - result = service.find_existing_report_output( - country_id="us", - simulation_1_id=102, - simulation_2_id=None, - year="2025", - ) - - assert result is None - - class TestCreateReportOutput: """Test creating new report outputs in the database.""" @@ -488,20 +366,313 @@ def test_create_report_output_populates_economy_comparison_report_spec( if isinstance(report_spec, str): report_spec = json.loads(report_spec) assert report_spec["region"] == "state/ca" - assert report_spec["baseline_policy_id"] == 30 - assert report_spec["reform_policy_id"] == 31 - assert report_spec["dataset"] == "default" - run = test_db.query( - "SELECT * FROM report_output_runs WHERE report_output_id = ?", - (created_report["id"],), + def test_create_report_output_reuses_same_explicit_economy_spec(self, test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ny", + population_type="geography", + policy_id=32, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ny", + population_type="geography", + policy_id=33, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/ny", + "baseline_policy_id": 32, + "reform_policy_id": 33, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + + first_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, 
+ ) + second_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + assert first_report["id"] == second_report["id"] + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (first_report["id"],), ).fetchone() - assert run is not None - snapshot = run["report_spec_snapshot_json"] - if isinstance(snapshot, str): - snapshot = json.loads(snapshot) - assert snapshot["report_kind"] == "economy_comparison" - assert snapshot["region"] == "state/ca" + assert stored_report["report_identity_hash"] is not None + assert stored_report["report_identity_schema_version"] == 1 + + def test_create_report_output_distinguishes_explicit_economy_specs_by_identity( + self, test_db + ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/tx", + population_type="geography", + policy_id=34, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/tx", + population_type="geography", + policy_id=35, + ) + default_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/tx", + "baseline_policy_id": 34, + "reform_policy_id": 35, + "dataset": "default", + "target": "general", + "options": {}, + } + ) + cliff_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/tx", + "baseline_policy_id": 34, + "reform_policy_id": 35, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + + first_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + 
report_spec=default_report_spec, + report_spec_schema_version=1, + ) + second_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=cliff_report_spec, + report_spec_schema_version=1, + ) + + assert first_report["id"] != second_report["id"] + stored_reports = test_db.query( + """ + SELECT id, report_identity_hash, report_spec_json + FROM report_outputs + WHERE country_id = ? AND simulation_1_id = ? AND simulation_2_id = ? AND year = ? + ORDER BY id + """, + ( + "us", + baseline_simulation["id"], + reform_simulation["id"], + "2026", + ), + ).fetchall() + assert len(stored_reports) == 2 + assert ( + stored_reports[0]["report_identity_hash"] + != stored_reports[1]["report_identity_hash"] + ) + + def test_create_report_output_loads_exact_inserted_row_for_explicit_spec( + self, test_db, monkeypatch + ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ks", + population_type="geography", + policy_id=39, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ks", + population_type="geography", + policy_id=40, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/ks", + "baseline_policy_id": 39, + "reform_policy_id": 40, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + + def fail_legacy_key_lookup(**_kwargs): + raise AssertionError("create should load the inserted row by primary key") + + monkeypatch.setattr( + service, + "_find_existing_report_output_row", + fail_legacy_key_lookup, + ) + + created_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + 
report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + assert created_report["simulation_1_id"] == baseline_simulation["id"] + assert created_report["simulation_2_id"] == reform_simulation["id"] + assert created_report["report_identity_hash"] is not None + + def test_create_report_output_reuses_stale_report_and_adds_current_run( + self, test_db + ): + stale_version = "r0stale1" + current_version = get_report_output_cache_version("us") + assert stale_version != current_version + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_create_stale_runtime", + population_type="household", + policy_id=41, + ) + test_db.query( + """ + INSERT INTO report_outputs + (country_id, simulation_1_id, simulation_2_id, status, output, api_version, year) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + "us", + simulation["id"], + None, + "complete", + json.dumps({"result": "stale"}), + stale_version, + "2026", + ), + ) + stale_report = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + result = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + + assert result["id"] == stale_report["id"] + assert result["status"] == "pending" + assert result["output"] is None + assert result["api_version"] == current_version + + report_rows = test_db.query( + """ + SELECT * FROM report_outputs + WHERE country_id = ? AND simulation_1_id = ? AND year = ? + """, + ("us", simulation["id"], "2026"), + ).fetchall() + assert len(report_rows) == 1 + + runs = test_db.query( + """ + SELECT * FROM report_output_runs + WHERE report_output_id = ? 
+ ORDER BY run_sequence ASC + """, + (stale_report["id"],), + ).fetchall() + assert len(runs) == 2 + assert runs[0]["status"] == "complete" + assert runs[0]["output"] == json.dumps({"result": "stale"}) + assert runs[0]["report_cache_version"] == stale_version + assert runs[1]["status"] == "pending" + assert runs[1]["output"] is None + assert runs[1]["report_cache_version"] == current_version + assert runs[1]["source_run_id"] == runs[0]["id"] + + def test_find_existing_for_create_validates_explicit_spec_context_before_reuse( + self, test_db + ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ia", + population_type="geography", + policy_id=36, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ia", + population_type="geography", + policy_id=37, + ) + mismatched_baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ia", + population_type="geography", + policy_id=38, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/ia", + "baseline_policy_id": 36, + "reform_policy_id": 37, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + with pytest.raises( + ValueError, match="Report spec baseline_policy_id must match" + ): + service.find_existing_report_output_for_create( + country_id="us", + simulation_1_id=mismatched_baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + ) class TestGetReportOutput: @@ -929,7 +1100,12 @@ def 
test_get_report_output_does_not_rewrite_terminal_active_run_for_running_pare "SELECT * FROM report_output_runs WHERE id = ?", (successful_run_id,), ).fetchone() - assert result["status"] == "running" + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + assert result["status"] == "complete" + assert stored_report["status"] == "running" assert successful_run["status"] == "complete" assert successful_run["output"] == output_json assert successful_run["finished_at"] is not None @@ -1155,106 +1331,113 @@ def test_get_report_output_bootstraps_running_legacy_run_started_at(self, test_d assert run["started_at"] is not None assert run["finished_at"] is None - def test_find_existing_report_output_backfills_missing_timestamps(self, test_db): + def test_get_report_output_uses_selected_display_run_for_canonical_parent( + self, test_db + ): simulation = simulation_service.create_simulation( country_id="us", - population_id="household_report_legacy_timestamp_find", + population_id="household_display_run", population_type="household", - policy_id=50, + policy_id=5, ) - report = service.create_report_output( + report_output = service.create_report_output( country_id="us", simulation_1_id=simulation["id"], simulation_2_id=None, year="2025", ) + service.update_report_output( + country_id="us", + report_id=report_output["id"], + status="complete", + output=json.dumps({"budget": {"budgetary_impact": 2}}), + ) test_db.query( """ - UPDATE report_output_runs - SET requested_at = NULL - WHERE report_output_id = ? + UPDATE report_outputs + SET status = ?, output = ?, api_version = ? + WHERE id = ? 
""", - (report["id"],), + ( + "pending", + None, + "r0stale1", + report_output["id"], + ), ) - result = service.find_existing_report_output( + result = service.get_report_output( + country_id="us", report_output_id=report_output["id"] + ) + + assert result is not None + assert result["id"] == report_output["id"] + assert result["status"] == "complete" + assert result["output"] == json.dumps({"budget": {"budgetary_impact": 2}}) + assert result["api_version"] == get_report_output_cache_version("us") + + def test_get_report_output_resolves_legacy_id_to_canonical_display_run( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_alias_display_run", + population_type="household", + policy_id=6, + ) + canonical_report = service.create_report_output( country_id="us", simulation_1_id=simulation["id"], simulation_2_id=None, year="2025", ) - - assert result is not None - assert result["requested_at"] is not None - - def test_get_report_output_resolves_stale_id_to_current_runtime_row(self, test_db): - stale_output = { - "budget": {"budgetary_impact": 1}, - "congressional_district_impact": { - "districts": [ - { - "district": "AL-01", - "average_household_income_change": 120, - "relative_household_income_change": 0.01, - } - ] - }, - } - test_db.query( - """INSERT INTO report_outputs - (country_id, simulation_1_id, simulation_2_id, status, output, api_version, year) - VALUES (?, ?, ?, ?, ?, ?, ?)""", - ( - "us", - 2, - None, - "complete", - json.dumps(stale_output), - "r0stale1", - "2025", - ), + service.update_report_output( + country_id="us", + report_id=canonical_report["id"], + status="complete", + output=json.dumps({"budget": {"budgetary_impact": 3}}), ) - - stale_record = test_db.query( - "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + canonical_run = test_db.query( + """ + SELECT * FROM report_output_runs + WHERE report_output_id = ? 
+ ORDER BY run_sequence DESC + LIMIT 1 + """, + (canonical_report["id"],), ).fetchone() - - current_version = get_report_output_cache_version("us") - test_db.query( - """INSERT INTO report_outputs - (country_id, simulation_1_id, simulation_2_id, status, output, api_version, year) - VALUES (?, ?, ?, ?, ?, ?, ?)""", - ( - "us", - 2, - None, - "complete", - json.dumps({"budget": {"budgetary_impact": 2}}), - current_version, - "2025", - ), + legacy_run = report_run_service.create_report_output_run( + canonical_report["id"], + status="error", + trigger_type="backfill", + output=json.dumps({"budget": {"budgetary_impact": -1}}), + error_message="legacy error", + ) + id_map_service.set_mapping( + legacy_report_output_id=999, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=legacy_run["id"], ) - current_record = test_db.query( - "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" - ).fetchone() + result = service.get_report_output(country_id="us", report_output_id=999) - result = service.get_report_output( - country_id="us", report_output_id=stale_record["id"] - ) assert result is not None - assert result["id"] == stale_record["id"] - assert result["api_version"] == current_record["api_version"] - assert result["output"] == current_record["output"] + assert result["id"] == 999 + assert result["status"] == "error" + assert result["output"] == json.dumps({"budget": {"budgetary_impact": -1}}) + assert result["error_message"] == "legacy error" + assert result["api_version"] == get_report_output_cache_version("us") + assert canonical_run["id"] != legacy_run["id"] - def test_get_report_output_creates_current_runtime_row_for_stale_id(self, test_db): + def test_get_report_output_does_not_create_current_runtime_row_for_stale_id( + self, test_db + ): stale_version = "r0stale1" - current_version = get_report_output_cache_version("us") simulation = simulation_service.create_simulation( country_id="us", - 
population_id="household_stale_runtime_create", + population_id="household_stale_runtime_read", population_type="household", - policy_id=5, + policy_id=7, ) test_db.query( @@ -1274,17 +1457,15 @@ def test_get_report_output_creates_current_runtime_row_for_stale_id(self, test_d assert result is not None assert result["id"] == stale_record["id"] - assert result["api_version"] == current_version - assert result["status"] == "pending" + assert result["api_version"] == stale_version + assert result["status"] == "complete" assert result["output"] is None - current_rows = test_db.query( + rows = test_db.query( "SELECT * FROM report_outputs WHERE country_id = ? AND simulation_1_id = ? AND year = ? ORDER BY id ASC", ("us", simulation["id"], "2025"), ).fetchall() - assert len(current_rows) == 2 - assert current_rows[0]["api_version"] == stale_version - assert current_rows[1]["api_version"] == current_version + assert len(rows) == 1 def test_get_report_output_invalid_id(self, test_db): """Test that invalid report IDs are handled properly.""" @@ -1465,6 +1646,163 @@ def test_update_report_output_updates_dual_write_state(self, test_db): assert run["output"] == output_json assert run["id"] == stored_report["latest_successful_run_id"] + def test_update_report_output_preserves_stored_explicit_report_spec(self, test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/co", + population_type="geography", + policy_id=61, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/co", + population_type="geography", + policy_id=62, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/co", + "baseline_policy_id": 61, + "reform_policy_id": 62, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + created_report = 
service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + success = service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="complete", + output=json.dumps({"result": "ok"}), + ) + + assert success is True + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (created_report["id"],), + ).fetchone() + assert stored_report["report_spec_status"] == "explicit" + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + snapshot = run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert snapshot["options"] == {"view": "tax"} + + def test_update_report_output_preserves_existing_run_metadata_without_overrides( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/az", + population_type="geography", + policy_id=63, + ) + created_report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + + service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="complete", + output=json.dumps({"result": "ok"}), + version_manifest_overrides={ + "country_package_version": "1.621.0", + "policyengine_version": "0.95.0", + "data_version": "2026.04.17", + "runtime_app_name": "policyengine-app-v2", + 
"resolved_dataset": "enhanced_us_household", + }, + ) + + service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="error", + error_message="later failure", + ) + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + assert run["status"] == "error" + assert run["error_message"] == "later failure" + assert run["country_package_version"] == "1.621.0" + assert run["policyengine_version"] == "0.95.0" + assert run["data_version"] == "2026.04.17" + assert run["runtime_app_name"] == "policyengine-app-v2" + assert run["resolved_dataset"] == "enhanced_us_household" + + def test_update_report_output_allows_explicit_metadata_override_on_existing_run( + self, test_db + ): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/nm", + population_type="geography", + policy_id=64, + ) + created_report = service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + + service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="complete", + output=json.dumps({"result": "ok"}), + version_manifest_overrides={ + "country_package_version": "1.621.0", + "policyengine_version": "0.95.0", + }, + ) + + service.update_report_output( + country_id="us", + report_id=created_report["id"], + version_manifest_overrides={ + "policyengine_version": "0.95.1", + }, + ) + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + assert run["country_package_version"] == "1.621.0" + assert run["policyengine_version"] == "0.95.1" + def test_update_report_output_bootstraps_missing_run_state(self, test_db): simulation_1 = simulation_service.create_simulation( country_id="us", @@ -1731,7 +2069,7 @@ def test_create_report_output_rolls_back_parent_insert_on_dual_write_failure( 
policy_id=34, ) - def fail_dual_write(tx, report_output_id, *, country_id=None): + def fail_dual_write(tx, report_output_id, *, country_id=None, **kwargs): raise RuntimeError("dual write sync failed") monkeypatch.setattr( @@ -1757,7 +2095,7 @@ def fail_dual_write(tx, report_output_id, *, country_id=None): ).fetchall() assert rows == [] - def test_update_report_output_rolls_back_parent_update_on_dual_write_failure( + def test_update_report_output_rolls_back_parent_update_on_run_write_failure( self, test_db, monkeypatch ): simulation = simulation_service.create_simulation( @@ -1773,16 +2111,16 @@ def test_update_report_output_rolls_back_parent_update_on_dual_write_failure( year="2025", ) - def fail_dual_write(tx, report_output_id, *, country_id=None): - raise RuntimeError("dual write sync failed") + def fail_run_update(*args, **kwargs): + raise RuntimeError("run update failed") monkeypatch.setattr( service, - "_ensure_report_output_dual_write_state_in_transaction", - fail_dual_write, + "_update_report_run_in_transaction", + fail_run_update, ) - with pytest.raises(RuntimeError, match="dual write sync failed"): + with pytest.raises(RuntimeError, match="run update failed"): service.update_report_output( country_id="us", report_id=created_report["id"], @@ -1895,3 +2233,146 @@ def test_ensure_report_output_dual_write_state_bootstraps_linked_simulations( ).fetchone() assert simulation_1_run is not None assert simulation_2_run is not None + + def test_ensure_report_output_dual_write_state_reuses_stored_explicit_report_spec( + self, test_db + ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/il", + population_type="geography", + policy_id=63, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/il", + population_type="geography", + policy_id=64, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": 
"economy_comparison", + "time_period": "2026", + "region": "state/il", + "baseline_policy_id": 63, + "reform_policy_id": 64, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + created_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + synced_report = service.ensure_report_output_dual_write_state( + created_report["id"], + country_id="us", + ) + + assert synced_report["report_spec_status"] == "explicit" + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (created_report["id"],), + ).fetchone() + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + snapshot = run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert snapshot["options"] == {"view": "tax"} + + def test_update_report_output_invalid_stored_explicit_report_spec_fails_closed( + self, test_db + ): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/mi", + population_type="geography", + policy_id=65, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/mi", + population_type="geography", + policy_id=66, + ) + explicit_report_spec = service.parse_report_spec_payload( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", 
+ "region": "state/mi", + "baseline_policy_id": 65, + "reform_policy_id": 66, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + created_report = service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + report_spec=explicit_report_spec, + report_spec_schema_version=1, + ) + + corrupted_spec = { + **explicit_report_spec.model_dump(), + "region": "state/ca", + } + test_db.query( + """ + UPDATE report_outputs + SET report_spec_json = ? + WHERE id = ? + """, + ( + json.dumps(corrupted_spec), + created_report["id"], + ), + ) + + with pytest.raises( + ValueError, match="Report spec region must match linked simulations" + ): + service.update_report_output( + country_id="us", + report_id=created_report["id"], + status="complete", + output=json.dumps({"result": "should_rollback"}), + ) + + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (created_report["id"],), + ).fetchone() + assert stored_report["status"] == "pending" + assert stored_report["output"] is None + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (created_report["id"],), + ).fetchone() + assert run is not None + assert run["status"] == "pending" + assert run["output"] is None diff --git a/tests/unit/services/test_report_spec_service.py b/tests/unit/services/test_report_spec_service.py index f924df8db..4cc0259e2 100644 --- a/tests/unit/services/test_report_spec_service.py +++ b/tests/unit/services/test_report_spec_service.py @@ -5,6 +5,7 @@ from policyengine_api.services.report_spec_service import ( EconomyReportSpec, HouseholdReportSpec, + REPORT_IDENTITY_SCHEMA_VERSION, ReportSpecService, ) from policyengine_api.services.simulation_service import SimulationService @@ -89,12 +90,12 @@ def test_raises_for_mixed_population_types(self, test_db): population_type="geography", policy_id=2, ) - 
report_output = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation_1["id"], - simulation_2_id=simulation_2["id"], - year="2025", - ) + report_output = { + "country_id": "us", + "simulation_1_id": simulation_1["id"], + "simulation_2_id": simulation_2["id"], + "year": "2025", + } with pytest.raises(ValueError) as exc_info: report_spec_service.build_report_spec( @@ -116,12 +117,12 @@ def test_raises_for_mismatched_household_ids(self, test_db): population_type="household", policy_id=2, ) - report_output = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation_1["id"], - simulation_2_id=simulation_2["id"], - year="2025", - ) + report_output = { + "country_id": "us", + "simulation_1_id": simulation_1["id"], + "simulation_2_id": simulation_2["id"], + "year": "2025", + } with pytest.raises(ValueError) as exc_info: report_spec_service.build_report_spec( @@ -143,12 +144,12 @@ def test_raises_for_mismatched_geography_ids(self, test_db): population_type="geography", policy_id=11, ) - report_output = report_output_service.create_report_output( - country_id="us", - simulation_1_id=simulation_1["id"], - simulation_2_id=simulation_2["id"], - year="2027", - ) + report_output = { + "country_id": "us", + "simulation_1_id": simulation_1["id"], + "simulation_2_id": simulation_2["id"], + "year": "2027", + } with pytest.raises(ValueError) as exc_info: report_spec_service.build_report_spec( @@ -467,3 +468,219 @@ def test_rejects_unsupported_schema_version_on_read(self, test_db): report_spec_service.get_report_spec(report_output["id"]) assert "Unsupported report spec schema version" in str(exc_info.value) + + +class TestReportIdentity: + def test_builds_household_identity_document(self): + report_spec = HouseholdReportSpec.model_validate( + { + "country_id": "uk", + "report_kind": "household_comparison", + "time_period": "2027", + "simulation_1": { + "population_type": "household", + "population_id": 
"household_1", + "policy_id": 1, + }, + "simulation_2": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 2, + }, + } + ) + + identity_document = report_spec_service.build_report_identity_document( + report_spec + ) + + assert identity_document == { + "schema_version": REPORT_IDENTITY_SCHEMA_VERSION, + "country_id": "uk", + "report_kind": "household_comparison", + "time_period": "2027", + "inputs": { + "simulation_1": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 1, + }, + "simulation_2": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 2, + }, + }, + } + + def test_builds_economy_identity_document_with_normalized_region(self): + report_spec = EconomyReportSpec.model_validate( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "region": "ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + } + ) + + identity_document = report_spec_service.build_report_identity_document( + report_spec + ) + + assert identity_document == { + "schema_version": REPORT_IDENTITY_SCHEMA_VERSION, + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "inputs": { + "region": "state/ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + } + + def test_canonicalize_helper_returns_identity_document(self): + report_spec = HouseholdReportSpec.model_validate( + { + "country_id": "uk", + "report_kind": "household_single", + "time_period": "2027", + "simulation_1": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 1, + }, + "simulation_2": None, + } + ) + + assert report_spec_service.canonicalize_report_spec_for_identity( + report_spec + ) == 
report_spec_service.build_report_identity_document(report_spec) + + def test_equal_specs_produce_equal_hashes_despite_nested_options_key_order(self): + first_spec = EconomyReportSpec.model_validate( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "region": "state/ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "default", + "target": "general", + "options": {"outer": {"b": 2, "a": 1}, "enabled": True}, + } + ) + second_spec = EconomyReportSpec.model_validate( + { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "region": "ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "default", + "target": "general", + "options": {"enabled": True, "outer": {"a": 1, "b": 2}}, + } + ) + + assert report_spec_service.get_report_identity_hash( + first_spec + ) == report_spec_service.get_report_identity_hash(second_spec) + + @pytest.mark.parametrize( + ("field_name", "replacement_value"), + [ + ("time_period", "2028"), + ("region", "state/ny"), + ("baseline_policy_id", 12), + ("reform_policy_id", 13), + ("dataset", "enhanced_us_household"), + ("target", "cliff"), + ("options", {"view": "tax"}), + ], + ) + def test_distinct_economy_definition_fields_change_identity_hash( + self, + field_name, + replacement_value, + ): + base_spec_data = { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2027", + "region": "state/ca", + "baseline_policy_id": 10, + "reform_policy_id": 11, + "dataset": "default", + "target": "general", + "options": {}, + } + first_spec = EconomyReportSpec.model_validate( + base_spec_data + ) + second_spec = EconomyReportSpec.model_validate( + { + **base_spec_data, + field_name: replacement_value, + } + ) + + assert report_spec_service.get_report_identity_hash( + first_spec + ) != report_spec_service.get_report_identity_hash(second_spec) + + def test_report_identity_returns_hash_and_schema_version(self): + 
report_spec = HouseholdReportSpec.model_validate( + { + "country_id": "uk", + "report_kind": "household_single", + "time_period": "2027", + "simulation_1": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 1, + }, + "simulation_2": None, + } + ) + + report_identity_hash, schema_version = report_spec_service.get_report_identity( + report_spec + ) + + assert len(report_identity_hash) == 64 + assert schema_version == REPORT_IDENTITY_SCHEMA_VERSION + + def test_rejects_unsupported_identity_schema_version(self): + report_spec = HouseholdReportSpec.model_validate( + { + "country_id": "uk", + "report_kind": "household_single", + "time_period": "2027", + "simulation_1": { + "population_type": "household", + "population_id": "household_1", + "policy_id": 1, + }, + "simulation_2": None, + } + ) + + with pytest.raises(ValueError) as exc_info: + report_spec_service.get_report_identity_hash( + report_spec, + schema_version=2, + ) + + assert "Unsupported report identity schema version" in str(exc_info.value) diff --git a/tests/unit/services/test_simulation_service.py b/tests/unit/services/test_simulation_service.py index 34116287f..fc53504dc 100644 --- a/tests/unit/services/test_simulation_service.py +++ b/tests/unit/services/test_simulation_service.py @@ -227,7 +227,7 @@ def test_create_simulation_reuses_existing_row_and_bootstraps_dual_write( def test_create_simulation_rolls_back_parent_insert_on_dual_write_failure( self, test_db, monkeypatch ): - def fail_dual_write(tx, simulation_id, *, country_id=None): + def fail_dual_write(tx, simulation_id, *, country_id=None, **kwargs): raise RuntimeError("dual write sync failed") monkeypatch.setattr( @@ -444,7 +444,7 @@ def test_update_simulation_does_not_append_extra_run_for_legacy_patch_traffic( assert runs[0]["id"] == first_run["id"] assert runs[0]["status"] == "complete" - def test_update_simulation_rolls_back_parent_update_on_dual_write_failure( + def 
test_update_simulation_rolls_back_parent_update_on_run_write_failure( self, test_db, monkeypatch ): created_simulation = service.create_simulation( @@ -454,16 +454,16 @@ def test_update_simulation_rolls_back_parent_update_on_dual_write_failure( policy_id=15, ) - def fail_dual_write(tx, simulation_id, *, country_id=None): - raise RuntimeError("dual write sync failed") + def fail_run_update(*args, **kwargs): + raise RuntimeError("run update failed") monkeypatch.setattr( service, - "_ensure_simulation_dual_write_state_in_transaction", - fail_dual_write, + "_update_simulation_run_in_transaction", + fail_run_update, ) - with pytest.raises(RuntimeError, match="dual write sync failed"): + with pytest.raises(RuntimeError, match="run update failed"): service.update_simulation( country_id="us", simulation_id=created_simulation["id"], @@ -520,3 +520,111 @@ def test_update_simulation_with_no_user_fields_returns_false(self, test_db): ).fetchone() assert post_row["api_version"] == pre_row["api_version"] assert post_row["status"] == pre_row["status"] + + def test_update_simulation_preserves_existing_run_metadata_without_overrides( + self, test_db + ): + created_simulation = service.create_simulation( + country_id="us", + population_id="household_metadata_preserve", + population_type="household", + policy_id=16, + ) + + service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + status="complete", + output=json.dumps({"result": "ok"}), + version_manifest_overrides={ + "country_package_version": "1.620.0", + "policyengine_version": "0.94.2", + "data_version": "2026.04.16", + "runtime_app_name": "policyengine-app-v2", + }, + ) + + service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + status="error", + error_message="later failure", + ) + + run = test_db.query( + "SELECT * FROM simulation_runs WHERE simulation_id = ?", + (created_simulation["id"],), + ).fetchone() + assert run["status"] == "error" + assert 
run["error_message"] == "later failure" + assert run["country_package_version"] == "1.620.0" + assert run["policyengine_version"] == "0.94.2" + assert run["data_version"] == "2026.04.16" + assert run["runtime_app_name"] == "policyengine-app-v2" + + def test_update_simulation_backfills_null_existing_run_metadata_from_parent( + self, test_db + ): + created_simulation = service.create_simulation( + country_id="us", + population_id="household_metadata_null_backfill", + population_type="household", + policy_id=17, + ) + test_db.query( + """ + UPDATE simulation_runs + SET country_package_version = NULL + WHERE simulation_id = ? + """, + (created_simulation["id"],), + ) + + service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + status="complete", + output=json.dumps({"result": "ok"}), + ) + + run = test_db.query( + "SELECT * FROM simulation_runs WHERE simulation_id = ?", + (created_simulation["id"],), + ).fetchone() + assert run["country_package_version"] == created_simulation["api_version"] + + def test_update_simulation_allows_explicit_metadata_override_on_existing_run( + self, test_db + ): + created_simulation = service.create_simulation( + country_id="us", + population_id="household_metadata_override", + population_type="household", + policy_id=18, + ) + + service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + status="complete", + output=json.dumps({"result": "ok"}), + version_manifest_overrides={ + "country_package_version": "1.620.0", + "policyengine_version": "0.94.2", + }, + ) + + service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + version_manifest_overrides={ + "policyengine_version": "0.95.0", + }, + ) + + run = test_db.query( + "SELECT * FROM simulation_runs WHERE simulation_id = ?", + (created_simulation["id"],), + ).fetchone() + assert run["country_package_version"] == "1.620.0" + assert run["policyengine_version"] == "0.95.0" diff --git 
a/tests/unit/test_stage5_routes.py b/tests/unit/test_stage5_routes.py index ec9f34e1b..054f2b8fd 100644 --- a/tests/unit/test_stage5_routes.py +++ b/tests/unit/test_stage5_routes.py @@ -5,14 +5,20 @@ from policyengine_api.constants import get_report_output_cache_version from policyengine_api.routes.report_output_routes import report_output_bp from policyengine_api.routes.simulation_routes import simulation_bp +from policyengine_api.services.report_output_id_map_service import ( + ReportOutputIdMapService, +) from policyengine_api.services.report_output_service import ReportOutputService from policyengine_api.services.report_run_service import ReportRunService +from policyengine_api.services.simulation_run_service import SimulationRunService from policyengine_api.services.simulation_service import SimulationService simulation_service = SimulationService() report_output_service = ReportOutputService() report_run_service = ReportRunService() +report_output_id_map_service = ReportOutputIdMapService() +simulation_run_service = SimulationRunService() def create_test_client() -> Flask: @@ -23,6 +29,20 @@ def create_test_client() -> Flask: return app.test_client() +def get_display_report_run_id(test_db, report_output_id: int) -> str: + row = test_db.query( + """ + SELECT id FROM report_output_runs + WHERE report_output_id = ? 
+ ORDER BY run_sequence DESC + LIMIT 1 + """, + (report_output_id,), + ).fetchone() + assert row is not None + return row["id"] + + def test_create_simulation_existing_row_repairs_dual_write_state(test_db): test_db.query( """INSERT INTO simulations @@ -122,6 +142,76 @@ def test_create_report_output_existing_row_repairs_dual_write_state(test_db): assert snapshot["report_kind"] == "household_single" +def test_create_report_output_existing_stale_row_adds_current_run(test_db): + stale_version = "r0stale1" + current_version = get_report_output_cache_version("us") + assert stale_version != current_version + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_stale_report_create", + population_type="household", + policy_id=42, + ) + test_db.query( + """ + INSERT INTO report_outputs ( + country_id, simulation_1_id, simulation_2_id, api_version, status, output, year + ) VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + "us", + simulation["id"], + None, + stale_version, + "complete", + json.dumps({"result": "stale"}), + "2026", + ), + ) + stale_report = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + client = create_test_client() + response = client.post( + "/us/report", + json={ + "simulation_1_id": simulation["id"], + "year": "2026", + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["id"] == stale_report["id"] + assert payload["result"]["status"] == "pending" + assert payload["result"]["output"] is None + assert payload["result"]["api_version"] == current_version + + report_rows = test_db.query( + """ + SELECT * FROM report_outputs + WHERE country_id = ? AND simulation_1_id = ? AND year = ? + """, + ("us", simulation["id"], "2026"), + ).fetchall() + assert len(report_rows) == 1 + + runs = test_db.query( + """ + SELECT * FROM report_output_runs + WHERE report_output_id = ? 
+ ORDER BY run_sequence ASC + """, + (stale_report["id"],), + ).fetchall() + assert len(runs) == 2 + assert runs[0]["report_cache_version"] == stale_version + assert runs[0]["output"] == json.dumps({"result": "stale"}) + assert runs[1]["report_cache_version"] == current_version + assert runs[1]["status"] == "pending" + + def test_post_report_output_returns_timestamp_fields_for_new_and_existing_report( test_db, ): @@ -167,6 +257,280 @@ def test_post_report_output_returns_timestamp_fields_for_new_and_existing_report assert existing_report["finished_at"] is None +def test_create_report_output_with_explicit_spec_persists_it(test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ny", + population_type="geography", + policy_id=45, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ny", + population_type="geography", + policy_id=46, + ) + + client = create_test_client() + response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/ny", + "baseline_policy_id": 45, + "reform_policy_id": 46, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + }, + ) + + assert response.status_code == 201 + report_id = response.get_json()["result"]["id"] + + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report_id,), + ).fetchone() + assert stored_report["report_kind"] == "economy_comparison" + assert stored_report["report_spec_schema_version"] == 1 + assert stored_report["report_spec_status"] == "explicit" + + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert 
report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + assert stored_report["report_identity_hash"] is not None + assert stored_report["report_identity_schema_version"] == 1 + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report_id,), + ).fetchone() + assert run is not None + snapshot = run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert snapshot["options"] == {"view": "tax"} + + +def test_create_report_output_same_explicit_spec_returns_existing_row(test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/va", + population_type="geography", + policy_id=53, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/va", + population_type="geography", + policy_id=54, + ) + payload = { + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/va", + "baseline_policy_id": 53, + "reform_policy_id": 54, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + } + + client = create_test_client() + first_response = client.post("/us/report", json=payload) + second_response = client.post("/us/report", json=payload) + + assert first_response.status_code == 201 + assert second_response.status_code == 200 + assert ( + first_response.get_json()["result"]["id"] + == second_response.get_json()["result"]["id"] + ) + + +def test_create_report_output_same_identity_after_cache_version_change_reuses_row( + test_db, monkeypatch +): + simulation = 
simulation_service.create_simulation( + country_id="us", + population_id="household_route_cache_version_reuse", + population_type="household", + policy_id=75, + ) + client = create_test_client() + payload = { + "simulation_1_id": simulation["id"], + "simulation_2_id": None, + "year": "2026", + } + + first_response = client.post("/us/report", json=payload) + monkeypatch.setattr( + "policyengine_api.services.report_output_service.get_report_output_cache_version", + lambda country_id: f"{country_id}-new-report-cache-version", + ) + second_response = client.post("/us/report", json=payload) + + assert first_response.status_code == 201 + assert second_response.status_code == 200 + assert ( + first_response.get_json()["result"]["id"] + == second_response.get_json()["result"]["id"] + ) + + rows = test_db.query("SELECT * FROM report_outputs").fetchall() + assert len(rows) == 1 + + +def test_create_report_output_distinct_explicit_specs_create_distinct_rows(test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/md", + population_type="geography", + policy_id=55, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/md", + population_type="geography", + policy_id=56, + ) + + client = create_test_client() + default_response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/md", + "baseline_policy_id": 55, + "reform_policy_id": 56, + "dataset": "default", + "target": "general", + "options": {}, + }, + }, + ) + cliff_response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + 
"report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/md", + "baseline_policy_id": 55, + "reform_policy_id": 56, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + }, + ) + + assert default_response.status_code == 201 + assert cliff_response.status_code == 201 + assert ( + default_response.get_json()["result"]["id"] + != cliff_response.get_json()["result"]["id"] + ) + + +def test_create_report_output_explicit_spec_validates_requested_simulations_before_reuse( + test_db, +): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ma", + population_type="geography", + policy_id=70, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ma", + population_type="geography", + policy_id=71, + ) + mismatched_baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/ma", + population_type="geography", + policy_id=72, + ) + payload = { + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/ma", + "baseline_policy_id": 70, + "reform_policy_id": 71, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + } + + client = create_test_client() + create_response = client.post("/us/report", json=payload) + missing_response = client.post( + "/us/report", + json={ + **payload, + "simulation_1_id": 999999, + }, + ) + mismatched_response = client.post( + "/us/report", + json={ + **payload, + "simulation_1_id": mismatched_baseline_simulation["id"], + }, + ) + + assert create_response.status_code == 201 + assert missing_response.status_code == 400 + assert 
mismatched_response.status_code == 400 + + report_rows = test_db.query("SELECT * FROM report_outputs").fetchall() + assert len(report_rows) == 1 + + def test_create_report_output_missing_primary_simulation_returns_bad_request(test_db): client = create_test_client() response = client.post( @@ -258,140 +622,410 @@ def test_patch_simulation_wrong_country_returns_not_found_and_does_not_mutate(te assert stored_simulation["output"] is None -def test_get_report_output_wrong_country_returns_not_found(test_db): - test_db.query( - """ - INSERT INTO report_outputs ( - country_id, simulation_1_id, simulation_2_id, api_version, status, year - ) VALUES (?, ?, ?, ?, ?, ?) - """, - ("us", 999, None, get_report_output_cache_version("us"), "pending", "2025"), - ) - report_output = test_db.query( - "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" - ).fetchone() - - client = create_test_client() - response = client.get(f"/uk/report/{report_output['id']}") - - assert response.status_code == 404 - - -def test_patch_report_output_wrong_country_returns_not_found_and_does_not_mutate( - test_db, -): - test_db.query( - """ - INSERT INTO report_outputs ( - country_id, simulation_1_id, simulation_2_id, api_version, status, year - ) VALUES (?, ?, ?, ?, ?, ?) 
- """, - ("us", 1000, None, get_report_output_cache_version("us"), "pending", "2025"), +def test_patch_simulation_persists_run_metadata_fields(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_metadata", + population_type="household", + policy_id=47, ) - report_output = test_db.query( - "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" - ).fetchone() client = create_test_client() response = client.patch( - "/uk/report", + "/us/simulation", json={ - "id": report_output["id"], + "id": simulation["id"], "status": "complete", - "output": json.dumps({"should_not": "persist"}), + "output": json.dumps({"ok": True}), + "country_package_version": "1.620.0", + "policyengine_version": "0.94.2", + "data_version": "2026.04.16", + "runtime_app_name": "policyengine-app-v2", }, ) - assert response.status_code == 404 - - stored_report = test_db.query( - "SELECT * FROM report_outputs WHERE id = ?", - (report_output["id"],), + assert response.status_code == 200 + run = test_db.query( + "SELECT * FROM simulation_runs WHERE simulation_id = ?", + (simulation["id"],), ).fetchone() - assert stored_report["country_id"] == "us" - assert stored_report["status"] == "pending" - assert stored_report["output"] is None + assert run["country_package_version"] == "1.620.0" + assert run["policyengine_version"] == "0.94.2" + assert run["data_version"] == "2026.04.16" + assert run["runtime_app_name"] == "policyengine-app-v2" -def test_patch_report_output_accepts_running_status(test_db): +def test_patch_simulation_explicit_run_id_updates_only_that_run(test_db): simulation = simulation_service.create_simulation( country_id="us", - population_id="household_route_running_report", + population_id="household_route_explicit_simulation_run", population_type="household", - policy_id=45, + policy_id=78, ) - report = report_output_service.create_report_output( + simulation_service.update_simulation( country_id="us", - 
simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2025", + simulation_id=simulation["id"], + status="complete", + output=json.dumps({"result": "initial"}), + ) + initial_run = test_db.query( + "SELECT * FROM simulation_runs WHERE simulation_id = ?", + (simulation["id"],), + ).fetchone() + rerun = simulation_run_service.create_simulation_run( + simulation["id"], + trigger_type="rerun", ) client = create_test_client() response = client.patch( - "/us/report", + "/us/simulation", json={ - "id": report["id"], - "status": "running", + "id": simulation["id"], + "simulation_run_id": rerun["id"], + "status": "complete", + "output": json.dumps({"result": "explicit rerun"}), }, ) assert response.status_code == 200 - payload = response.get_json() - assert payload["result"]["status"] == "running" - assert payload["result"]["requested_at"] is not None - assert payload["result"]["started_at"] is not None - assert payload["result"]["finished_at"] is None + initial_run_after = test_db.query( + "SELECT * FROM simulation_runs WHERE id = ?", + (initial_run["id"],), + ).fetchone() + rerun_after = test_db.query( + "SELECT * FROM simulation_runs WHERE id = ?", + (rerun["id"],), + ).fetchone() + assert initial_run_after["output"] == json.dumps({"result": "initial"}) + assert rerun_after["output"] == json.dumps({"result": "explicit rerun"}) -def test_get_report_output_serializes_display_run_timestamps(test_db): +def test_patch_simulation_explicit_run_id_response_uses_that_run(test_db): simulation = simulation_service.create_simulation( country_id="us", - population_id="household_route_get_timestamp", + population_id="household_route_explicit_simulation_run_response", population_type="household", - policy_id=47, + policy_id=79, ) - report = report_output_service.create_report_output( + simulation_service.update_simulation( country_id="us", - simulation_1_id=simulation["id"], - simulation_2_id=None, - year="2025", + simulation_id=simulation["id"], + status="complete", + 
output=json.dumps({"result": "initial"}), ) - report_output_service.update_report_output( - country_id="us", - report_id=report["id"], + older_run = simulation_run_service.create_simulation_run( + simulation["id"], status="complete", - output=json.dumps({"ok": True}), + trigger_type="rerun", + output=json.dumps({"result": "older before patch"}), + ) + newer_run = simulation_run_service.create_simulation_run( + simulation["id"], + status="complete", + trigger_type="rerun", + output=json.dumps({"result": "newer display"}), ) - run = test_db.query( - "SELECT * FROM report_output_runs WHERE report_output_id = ?", - (report["id"],), - ).fetchone() test_db.query( """ - UPDATE report_output_runs - SET requested_at = ?, started_at = ?, finished_at = ? + UPDATE simulations + SET latest_successful_run_id = ? WHERE id = ? """, - ( - "2026-05-04 12:00:00", - "2026-05-04 12:01:00", - "2026-05-04 12:02:00", - run["id"], - ), + (newer_run["id"], simulation["id"]), ) client = create_test_client() - response = client.get(f"/us/report/{report['id']}") + response = client.patch( + "/us/simulation", + json={ + "id": simulation["id"], + "simulation_run_id": older_run["id"], + "status": "complete", + "output": json.dumps({"result": "older patched"}), + }, + ) assert response.status_code == 200 payload = response.get_json() - assert payload["result"]["requested_at"] == "2026-05-04T12:00:00Z" - assert payload["result"]["started_at"] == "2026-05-04T12:01:00Z" - assert payload["result"]["finished_at"] == "2026-05-04T12:02:00Z" + assert payload["result"]["output"] == json.dumps({"result": "older patched"}) + get_response = client.get(f"/us/simulation/{simulation['id']}") + assert get_response.status_code == 200 + get_payload = get_response.get_json() + assert get_payload["result"]["output"] == json.dumps({"result": "newer display"}) -def test_patch_report_output_running_uses_active_rerun_route_path(test_db): + +def test_patch_simulation_rejects_non_string_run_metadata(test_db): + simulation 
= simulation_service.create_simulation( + country_id="us", + population_id="household_route_invalid_metadata", + population_type="household", + policy_id=73, + ) + + client = create_test_client() + response = client.patch( + "/us/simulation", + json={ + "id": simulation["id"], + "country_package_version": 123, + }, + ) + + assert response.status_code == 400 + + +def test_get_report_output_wrong_country_returns_not_found(test_db): + test_db.query( + """ + INSERT INTO report_outputs ( + country_id, simulation_1_id, simulation_2_id, api_version, status, year + ) VALUES (?, ?, ?, ?, ?, ?) + """, + ("us", 999, None, get_report_output_cache_version("us"), "pending", "2025"), + ) + report_output = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + client = create_test_client() + response = client.get(f"/uk/report/{report_output['id']}") + + assert response.status_code == 404 + + +def test_get_report_output_legacy_id_wrong_country_returns_not_found(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_alias_wrong_country", + population_type="household", + policy_id=56, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + report_output_id_map_service.set_mapping( + legacy_report_output_id=2000, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=get_display_report_run_id( + test_db, canonical_report["id"] + ), + ) + + client = create_test_client() + response = client.get("/uk/report/2000") + + assert response.status_code == 404 + + +def test_get_report_output_legacy_id_resolves_to_pinned_display_run(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_alias", + population_type="household", + policy_id=57, + ) + canonical_report = 
report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + report_output_service.update_report_output( + country_id="us", + report_id=canonical_report["id"], + status="complete", + output=json.dumps({"result": "canonical"}), + ) + legacy_run = report_run_service.create_report_output_run( + canonical_report["id"], + status="error", + trigger_type="backfill", + output=json.dumps({"result": "legacy"}), + error_message="legacy failure", + ) + report_output_id_map_service.set_mapping( + legacy_report_output_id=2001, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=legacy_run["id"], + ) + + client = create_test_client() + response = client.get("/us/report/2001") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["id"] == 2001 + assert payload["result"]["status"] == "error" + assert payload["result"]["output"] == json.dumps({"result": "legacy"}) + assert payload["result"]["error_message"] == "legacy failure" + + +def test_get_report_output_reads_malformed_legacy_row_without_runs_or_identity( + test_db, +): + household_simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_legacy_malformed", + population_type="household", + policy_id=58, + ) + geography_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/co", + population_type="geography", + policy_id=59, + ) + test_db.query( + """ + INSERT INTO report_outputs ( + country_id, simulation_1_id, simulation_2_id, api_version, status, output, year + ) VALUES (?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + "us", + household_simulation["id"], + geography_simulation["id"], + "r0legacy-malformed", + "error", + json.dumps({"result": "legacy-malformed"}), + "2025", + ), + ) + malformed_report = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + client = create_test_client() + response = client.get(f"/us/report/{malformed_report['id']}") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["id"] == malformed_report["id"] + assert payload["result"]["status"] == "error" + assert payload["result"]["output"] == json.dumps({"result": "legacy-malformed"}) + assert payload["result"]["api_version"] == "r0legacy-malformed" + + +def test_patch_report_output_wrong_country_returns_not_found_and_does_not_mutate( + test_db, +): + test_db.query( + """ + INSERT INTO report_outputs ( + country_id, simulation_1_id, simulation_2_id, api_version, status, year + ) VALUES (?, ?, ?, ?, ?, ?) + """, + ("us", 1000, None, get_report_output_cache_version("us"), "pending", "2025"), + ) + report_output = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + client = create_test_client() + response = client.patch( + "/uk/report", + json={ + "id": report_output["id"], + "status": "complete", + "output": json.dumps({"should_not": "persist"}), + }, + ) + + assert response.status_code == 404 + + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report_output["id"],), + ).fetchone() + assert stored_report["country_id"] == "us" + assert stored_report["status"] == "pending" + assert stored_report["output"] is None + + +def test_patch_report_output_accepts_running_status(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_running_report", + population_type="household", + policy_id=45, + ) + report = report_output_service.create_report_output( + country_id="us", + 
simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": report["id"], + "status": "running", + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["status"] == "running" + assert payload["result"]["requested_at"] is not None + assert payload["result"]["started_at"] is not None + assert payload["result"]["finished_at"] is None + + +def test_get_report_output_serializes_display_run_timestamps(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_get_timestamp", + population_type="household", + policy_id=47, + ) + report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2025", + ) + report_output_service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"ok": True}), + ) + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report["id"],), + ).fetchone() + test_db.query( + """ + UPDATE report_output_runs + SET requested_at = ?, started_at = ?, finished_at = ? + WHERE id = ? 
+ """, + ( + "2026-05-04 12:00:00", + "2026-05-04 12:01:00", + "2026-05-04 12:02:00", + run["id"], + ), + ) + + client = create_test_client() + response = client.get(f"/us/report/{report['id']}") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["requested_at"] == "2026-05-04T12:00:00Z" + assert payload["result"]["started_at"] == "2026-05-04T12:01:00Z" + assert payload["result"]["finished_at"] == "2026-05-04T12:02:00Z" + + +def test_patch_report_output_running_uses_active_rerun_route_path(test_db): simulation = simulation_service.create_simulation( country_id="us", population_id="household_route_active_running_rerun", @@ -580,3 +1214,722 @@ def test_patch_report_output_complete_promotes_active_rerun_route_path(test_db): ).fetchone() assert stored_report["active_run_id"] is None assert stored_report["latest_successful_run_id"] == rerun["id"] + + +def test_patch_report_output_explicit_run_id_updates_only_that_run(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_explicit_report_run", + population_type="household", + policy_id=76, + ) + report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + report_output_service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"result": "initial"}), + ) + initial_run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report["id"],), + ).fetchone() + rerun = report_run_service.create_report_output_run( + report["id"], trigger_type="rerun" + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": report["id"], + "report_output_run_id": rerun["id"], + "status": "complete", + "output": json.dumps({"result": "explicit rerun"}), + }, + ) + + assert response.status_code == 200 + initial_run_after = 
test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (initial_run["id"],), + ).fetchone() + rerun_after = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (rerun["id"],), + ).fetchone() + assert initial_run_after["output"] == json.dumps({"result": "initial"}) + assert rerun_after["output"] == json.dumps({"result": "explicit rerun"}) + + +def test_patch_report_output_explicit_run_id_response_uses_that_run(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_explicit_report_run_response", + population_type="household", + policy_id=78, + ) + report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + report_output_service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"result": "initial"}), + ) + older_run = report_run_service.create_report_output_run( + report["id"], + status="complete", + trigger_type="rerun", + output=json.dumps({"result": "older before patch"}), + ) + newer_run = report_run_service.create_report_output_run( + report["id"], + status="complete", + trigger_type="rerun", + output=json.dumps({"result": "newer display"}), + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": report["id"], + "report_output_run_id": older_run["id"], + "status": "complete", + "output": json.dumps({"result": "older patched"}), + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["output"] == json.dumps({"result": "older patched"}) + assert payload["result"]["finished_at"] is not None + + get_response = client.get(f"/us/report/{report['id']}") + assert get_response.status_code == 200 + get_payload = get_response.get_json() + assert get_payload["result"]["output"] == json.dumps({"result": "newer display"}) + + stored_report = 
test_db.query( + "SELECT latest_successful_run_id FROM report_outputs WHERE id = ?", + (report["id"],), + ).fetchone() + assert stored_report["latest_successful_run_id"] == newer_run["id"] + + +def test_patch_report_output_explicit_run_id_through_legacy_id_updates_canonical_run( + test_db, +): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_legacy_explicit_report_run", + population_type="household", + policy_id=79, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + report_output_service.update_report_output( + country_id="us", + report_id=canonical_report["id"], + status="complete", + output=json.dumps({"result": "initial"}), + ) + initial_run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (canonical_report["id"],), + ).fetchone() + rerun = report_run_service.create_report_output_run( + canonical_report["id"], trigger_type="rerun" + ) + report_output_id_map_service.set_mapping( + legacy_report_output_id=3002, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=initial_run["id"], + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": 3002, + "report_output_run_id": rerun["id"], + "status": "complete", + "output": json.dumps({"result": "legacy explicit rerun"}), + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["id"] == 3002 + + initial_run_after = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (initial_run["id"],), + ).fetchone() + rerun_after = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (rerun["id"],), + ).fetchone() + assert initial_run_after["output"] == json.dumps({"result": "initial"}) + assert rerun_after["report_output_id"] == canonical_report["id"] + assert 
rerun_after["output"] == json.dumps({"result": "legacy explicit rerun"}) + + +def test_patch_report_output_legacy_id_defaults_to_pinned_display_run(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_legacy_pinned_patch", + population_type="household", + policy_id=81, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + report_output_service.update_report_output( + country_id="us", + report_id=canonical_report["id"], + status="complete", + output=json.dumps({"result": "canonical initial"}), + ) + canonical_run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (canonical_report["id"],), + ).fetchone() + legacy_run = report_run_service.create_report_output_run( + canonical_report["id"], + status="pending", + trigger_type="backfill", + ) + report_output_id_map_service.set_mapping( + legacy_report_output_id=3003, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=legacy_run["id"], + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": 3003, + "status": "complete", + "output": json.dumps({"result": "legacy patched"}), + }, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["result"]["id"] == 3003 + assert payload["result"]["output"] == json.dumps({"result": "legacy patched"}) + + canonical_run_after = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (canonical_run["id"],), + ).fetchone() + legacy_run_after = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (legacy_run["id"],), + ).fetchone() + assert canonical_run_after["output"] == json.dumps({"result": "canonical initial"}) + assert legacy_run_after["output"] == json.dumps({"result": "legacy patched"}) + + +def 
test_create_report_rerun_via_canonical_id_creates_canonical_linked_runs(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_canonical_rerun", + population_type="household", + policy_id=80, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + report_output_service.update_report_output( + country_id="us", + report_id=canonical_report["id"], + status="complete", + output=json.dumps({"result": "canonical"}), + ) + + client = create_test_client() + response = client.post(f"/us/report/{canonical_report['id']}/rerun", json={}) + + assert response.status_code == 201 + result = response.get_json()["result"] + assert result["requested_report_output_id"] == canonical_report["id"] + assert result["report_output_id"] == canonical_report["id"] + assert len(result["simulation_run_ids"]) == 1 + + report_runs = test_db.query( + """ + SELECT * FROM report_output_runs + WHERE report_output_id = ? 
+ ORDER BY run_sequence + """, + (canonical_report["id"],), + ).fetchall() + assert len(report_runs) == 2 + assert report_runs[0]["trigger_type"] == "initial" + assert report_runs[1]["id"] == result["report_output_run_id"] + assert report_runs[1]["trigger_type"] == "rerun" + assert report_runs[1]["status"] == "pending" + + simulation_run = test_db.query( + "SELECT * FROM simulation_runs WHERE id = ?", + (result["simulation_run_ids"][0],), + ).fetchone() + assert simulation_run["report_output_run_id"] == result["report_output_run_id"] + assert simulation_run["input_position"] == 1 + + +def test_create_report_rerun_via_legacy_id_creates_canonical_linked_runs(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="household_route_legacy_rerun", + population_type="household", + policy_id=77, + ) + canonical_report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + report_output_service.update_report_output( + country_id="us", + report_id=canonical_report["id"], + status="complete", + output=json.dumps({"result": "canonical"}), + ) + report_output_id_map_service.set_mapping( + legacy_report_output_id=3001, + canonical_report_output_id=canonical_report["id"], + display_report_output_run_id=get_display_report_run_id( + test_db, canonical_report["id"] + ), + ) + + client = create_test_client() + response = client.post("/us/report/3001/rerun", json={}) + + assert response.status_code == 201 + result = response.get_json()["result"] + assert result["requested_report_output_id"] == 3001 + assert result["report_output_id"] == canonical_report["id"] + assert len(result["simulation_run_ids"]) == 1 + + report_run = test_db.query( + "SELECT * FROM report_output_runs WHERE id = ?", + (result["report_output_run_id"],), + ).fetchone() + assert report_run["report_output_id"] == canonical_report["id"] + assert report_run["trigger_type"] == "rerun" + + 
simulation_run = test_db.query( + "SELECT * FROM simulation_runs WHERE id = ?", + (result["simulation_run_ids"][0],), + ).fetchone() + assert simulation_run["report_output_run_id"] == result["report_output_run_id"] + assert simulation_run["input_position"] == 1 + + +def test_create_report_rerun_for_comparison_report_creates_two_linked_simulation_runs( + test_db, +): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/nc", + population_type="geography", + policy_id=81, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/nc", + population_type="geography", + policy_id=82, + ) + report = report_output_service.create_report_output( + country_id="us", + simulation_1_id=baseline_simulation["id"], + simulation_2_id=reform_simulation["id"], + year="2026", + ) + report_output_service.update_report_output( + country_id="us", + report_id=report["id"], + status="complete", + output=json.dumps({"result": "comparison"}), + ) + + client = create_test_client() + response = client.post(f"/us/report/{report['id']}/rerun", json={}) + + assert response.status_code == 201 + result = response.get_json()["result"] + assert result["report_output_id"] == report["id"] + assert len(result["simulation_run_ids"]) == 2 + + linked_simulation_runs = test_db.query( + """ + SELECT * FROM simulation_runs + WHERE report_output_run_id = ? 
+ ORDER BY input_position + """, + (result["report_output_run_id"],), + ).fetchall() + assert [run["simulation_id"] for run in linked_simulation_runs] == [ + baseline_simulation["id"], + reform_simulation["id"], + ] + assert [run["input_position"] for run in linked_simulation_runs] == [1, 2] + assert [run["status"] for run in linked_simulation_runs] == [ + "pending", + "pending", + ] + + +def test_create_report_rerun_rejects_report_with_missing_linked_simulation(test_db): + test_db.query( + """ + INSERT INTO report_outputs ( + country_id, simulation_1_id, simulation_2_id, api_version, status, year + ) VALUES (?, ?, ?, ?, ?, ?) + """, + ("us", 987654, None, get_report_output_cache_version("us"), "complete", "2026"), + ) + report = test_db.query( + "SELECT * FROM report_outputs ORDER BY id DESC LIMIT 1" + ).fetchone() + + client = create_test_client() + response = client.post(f"/us/report/{report['id']}/rerun", json={}) + + assert response.status_code == 400 + assert "Simulation #987654 not found" in response.get_data(as_text=True) + + report_runs = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report["id"],), + ).fetchall() + assert report_runs == [] + + +def test_report_rerun_http_lifecycle_patches_linked_runs_and_reads_display( + test_db, +): + client = create_test_client() + simulation_response = client.post( + "/us/simulation", + json={ + "population_id": "household_route_http_lifecycle", + "population_type": "household", + "policy_id": 83, + }, + ) + assert simulation_response.status_code == 201 + simulation = simulation_response.get_json()["result"] + + report_response = client.post( + "/us/report", + json={ + "simulation_1_id": simulation["id"], + "simulation_2_id": None, + "year": "2026", + }, + ) + assert report_response.status_code == 201 + report = report_response.get_json()["result"] + + initial_patch_response = client.patch( + "/us/report", + json={ + "id": report["id"], + "status": "complete", + "output": 
json.dumps({"result": "initial report"}), + }, + ) + assert initial_patch_response.status_code == 200 + + rerun_response = client.post(f"/us/report/{report['id']}/rerun", json={}) + assert rerun_response.status_code == 201 + rerun = rerun_response.get_json()["result"] + assert len(rerun["simulation_run_ids"]) == 1 + + simulation_patch_response = client.patch( + "/us/simulation", + json={ + "id": simulation["id"], + "simulation_run_id": rerun["simulation_run_ids"][0], + "status": "complete", + "output": json.dumps({"result": "rerun simulation"}), + }, + ) + assert simulation_patch_response.status_code == 200 + + report_patch_response = client.patch( + "/us/report", + json={ + "id": report["id"], + "report_output_run_id": rerun["report_output_run_id"], + "status": "complete", + "output": json.dumps({"result": "rerun report"}), + }, + ) + assert report_patch_response.status_code == 200 + + get_response = client.get(f"/us/report/{report['id']}") + assert get_response.status_code == 200 + result = get_response.get_json()["result"] + assert result["id"] == report["id"] + assert result["status"] == "complete" + assert result["output"] == json.dumps({"result": "rerun report"}) + + report_rows = test_db.query("SELECT * FROM report_outputs").fetchall() + report_runs = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report["id"],), + ).fetchall() + linked_simulation_runs = test_db.query( + """ + SELECT * FROM simulation_runs + WHERE report_output_run_id = ? 
+ """, + (rerun["report_output_run_id"],), + ).fetchall() + assert len(report_rows) == 1 + assert len(report_runs) == 2 + assert len(linked_simulation_runs) == 1 + + +def test_patch_report_output_persists_run_metadata_fields(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/wa", + population_type="geography", + policy_id=48, + ) + report_output = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": report_output["id"], + "status": "complete", + "output": json.dumps({"result": "ok"}), + "country_package_version": "1.621.0", + "policyengine_version": "0.95.0", + "data_version": "2026.04.17", + "runtime_app_name": "policyengine-app-v2", + "resolved_dataset": "enhanced_us_household", + }, + ) + + assert response.status_code == 200 + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report_output["id"],), + ).fetchone() + assert run["country_package_version"] == "1.621.0" + assert run["policyengine_version"] == "0.95.0" + assert run["data_version"] == "2026.04.17" + assert run["runtime_app_name"] == "policyengine-app-v2" + assert run["resolved_dataset"] == "enhanced_us_household" + + +def test_patch_report_output_rejects_non_string_run_metadata(test_db): + simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/mt", + population_type="geography", + policy_id=74, + ) + report_output = report_output_service.create_report_output( + country_id="us", + simulation_1_id=simulation["id"], + simulation_2_id=None, + year="2026", + ) + + client = create_test_client() + response = client.patch( + "/us/report", + json={ + "id": report_output["id"], + "policyengine_version": 123, + }, + ) + + assert response.status_code == 400 + + +def 
test_patch_report_output_preserves_stored_explicit_report_spec(test_db): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/or", + population_type="geography", + policy_id=49, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/or", + population_type="geography", + policy_id=50, + ) + + client = create_test_client() + create_response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/or", + "baseline_policy_id": 49, + "reform_policy_id": 50, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + }, + ) + report_id = create_response.get_json()["result"]["id"] + + patch_response = client.patch( + "/us/report", + json={ + "id": report_id, + "status": "complete", + "output": json.dumps({"result": "ok"}), + }, + ) + + assert patch_response.status_code == 200 + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report_id,), + ).fetchone() + assert stored_report["report_spec_status"] == "explicit" + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report_id,), + ).fetchone() + assert run is not None + snapshot = run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert 
snapshot["options"] == {"view": "tax"} + + +def test_patch_report_output_metadata_only_preserves_stored_explicit_report_spec( + test_db, +): + baseline_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/nj", + population_type="geography", + policy_id=51, + ) + reform_simulation = simulation_service.create_simulation( + country_id="us", + population_id="state/nj", + population_type="geography", + policy_id=52, + ) + + client = create_test_client() + create_response = client.post( + "/us/report", + json={ + "simulation_1_id": baseline_simulation["id"], + "simulation_2_id": reform_simulation["id"], + "year": "2026", + "report_spec_schema_version": 1, + "report_spec": { + "country_id": "us", + "report_kind": "economy_comparison", + "time_period": "2026", + "region": "state/nj", + "baseline_policy_id": 51, + "reform_policy_id": 52, + "dataset": "enhanced_us_household", + "target": "cliff", + "options": {"view": "tax"}, + }, + }, + ) + report_id = create_response.get_json()["result"]["id"] + + patch_response = client.patch( + "/us/report", + json={ + "id": report_id, + "policyengine_version": "0.95.1", + "runtime_app_name": "policyengine-app-v2", + }, + ) + + assert patch_response.status_code == 200 + stored_report = test_db.query( + "SELECT * FROM report_outputs WHERE id = ?", + (report_id,), + ).fetchone() + assert stored_report["report_spec_status"] == "explicit" + report_spec = stored_report["report_spec_json"] + if isinstance(report_spec, str): + report_spec = json.loads(report_spec) + assert report_spec["dataset"] == "enhanced_us_household" + assert report_spec["target"] == "cliff" + assert report_spec["options"] == {"view": "tax"} + + run = test_db.query( + "SELECT * FROM report_output_runs WHERE report_output_id = ?", + (report_id,), + ).fetchone() + assert run is not None + assert run["policyengine_version"] == "0.95.1" + assert run["runtime_app_name"] == "policyengine-app-v2" + snapshot = 
run["report_spec_snapshot_json"] + if isinstance(snapshot, str): + snapshot = json.loads(snapshot) + assert snapshot["dataset"] == "enhanced_us_household" + assert snapshot["target"] == "cliff" + assert snapshot["options"] == {"view": "tax"} diff --git a/uv.lock b/uv.lock index 480fe985f..b99629c78 100644 --- a/uv.lock +++ b/uv.lock @@ -146,6 +146,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "alembic" +version = "1.18.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mako" }, + { name = "sqlalchemy" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/13/8b084e0f2efb0275a1d534838844926f798bd766566b1375174e2448cd31/alembic-1.18.4.tar.gz", hash = "sha256:cb6e1fd84b6174ab8dbb2329f86d631ba9559dd78df550b57804d607672cedbc", size = 2056725, upload-time = "2026-02-10T16:00:47.195Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/29/6533c317b74f707ea28f8d633734dbda2119bbadfc61b2f3640ba835d0f7/alembic-1.18.4-py3-none-any.whl", hash = "sha256:a5ed4adcf6d8a4cb575f3d759f071b03cd6e5c7618eb796cb52497be25bfe19a", size = 263893, upload-time = "2026-02-10T16:00:49.997Z" }, +] + [[package]] name = "altair" version = "6.1.0" @@ -1811,6 +1825,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/91/6c074015990f4f656f7b69a5c2d15924906ce0bc19c7014ac953493c0cf0/linecheck-0.1.0-py3-none-any.whl", hash = "sha256:73c6b29790521fa711b00df7cd60af4caf7004337d8710606881fbecb0d1bc83", size = 2767, upload-time = "2022-07-16T13:06:17.01Z" }, ] +[[package]] +name = "mako" +version = "1.3.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/00/62/791b31e69ae182791ec67f04850f2f062716bbd205483d63a215f3e062d3/mako-1.3.12.tar.gz", hash = "sha256:9f778e93289bd410bb35daadeb4fc66d95a746f0b75777b942088b7fd7af550a", size = 400219, upload-time = "2026-04-28T19:01:08.512Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/b1/a0ec7a5a9db730a08daef1fdfb8090435b82465abbf758a596f0ea88727e/mako-1.3.12-py3-none-any.whl", hash = "sha256:8f61569480282dbf557145ce441e4ba888be453c30989f879f0d652e39f53ea9", size = 78521, upload-time = "2026-04-28T19:01:10.393Z" }, +] + [[package]] name = "markdown-it-py" version = "4.0.0" @@ -2622,9 +2648,10 @@ sdist = { url = "https://files.pythonhosted.org/packages/a0/f3/eeea7dab690e46cd9 [[package]] name = "policyengine-api" -version = "3.40.11" +version = "3.40.12" source = { editable = "." } dependencies = [ + { name = "alembic" }, { name = "anthropic" }, { name = "assertpy" }, { name = "click" }, @@ -2670,6 +2697,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "alembic", specifier = ">=1.13.0" }, { name = "anthropic" }, { name = "assertpy" }, { name = "build", marker = "extra == 'dev'" },