From c954f403eafcb52729f7a5201b3c957ca1b9e1dc Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 12 Jan 2026 22:11:36 +0300 Subject: [PATCH 1/2] feat: add version filtering to metadata endpoints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add optional `tax_benefit_model_version_id` filter to /parameters/, /variables/, and /parameter-values/ endpoints. When only model name is provided (without version ID), endpoints now default to returning data from the latest version (by created_at timestamp). Changes: - Add shared tax_benefit_models service with helper functions - Update /parameters/ endpoint with version filtering - Update /variables/ endpoint with version filtering - Update /parameter-values/ endpoint with model name and version filtering - Refactor /analysis/economic-impact to use shared helper - Add comprehensive tests for version filtering behavior 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/policyengine_api/api/analysis.py | 32 +-- src/policyengine_api/api/parameter_values.py | 22 ++- src/policyengine_api/api/parameters.py | 23 ++- src/policyengine_api/api/variables.py | 23 ++- .../services/tax_benefit_models.py | 107 ++++++++++ test_fixtures/fixtures_parameters.py | 98 +++++++++- tests/test_parameters.py | 184 ++++++++++++++++++ tests/test_variables.py | 136 ++++++++++++- 8 files changed, 568 insertions(+), 57 deletions(-) create mode 100644 src/policyengine_api/services/tax_benefit_models.py diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index c9aa86d..276144c 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -35,10 +35,9 @@ ReportStatus, Simulation, SimulationStatus, - TaxBenefitModel, - TaxBenefitModelVersion, ) from policyengine_api.services.database import get_session +from policyengine_api.services.tax_benefit_models import get_latest_model_version def get_traceparent() -> str | None: @@ -110,33 +109,6 @@ class EconomicImpactResponse(BaseModel): program_statistics: list[ProgramStatisticsRead] | None = None -def _get_model_version( - tax_benefit_model_name: str, session: Session -) -> TaxBenefitModelVersion: - """Get the latest tax benefit model version.""" - model_name = tax_benefit_model_name.replace("_", "-") - - model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == model_name) - ).first() - if not model: - raise HTTPException( - status_code=404, detail=f"Tax benefit model {model_name} not found" - ) - - version = session.exec( - select(TaxBenefitModelVersion) - .where(TaxBenefitModelVersion.model_id == model.id) - .order_by(TaxBenefitModelVersion.created_at.desc()) - ).first() - if not version: - raise HTTPException( - status_code=404, detail=f"No version found for model {model_name}" - ) - - return version - - def _get_deterministic_simulation_id( dataset_id: UUID, model_version_id: UUID, @@ -576,7 +548,7 @@ def economic_impact( ) # Get model version - model_version = _get_model_version(request.tax_benefit_model_name, session) + model_version = get_latest_model_version(request.tax_benefit_model_name, session) # Get or create simulations baseline_sim = _get_or_create_simulation( diff --git a/src/policyengine_api/api/parameter_values.py b/src/policyengine_api/api/parameter_values.py index 4668ab8..2298746 100644 --- a/src/policyengine_api/api/parameter_values.py +++ b/src/policyengine_api/api/parameter_values.py @@ -12,8 +12,9 @@ from fastapi import APIRouter, Depends, HTTPException from sqlmodel import Session, or_, select -from policyengine_api.models import ParameterValue, ParameterValueRead +from policyengine_api.models import Parameter, ParameterValue, ParameterValueRead from policyengine_api.services.database import get_session +from policyengine_api.services.tax_benefit_models import resolve_model_version_id router = APIRouter(prefix="/parameter-values", tags=["parameter-values"]) @@ -23,6 +24,8 @@ def list_parameter_values( parameter_id: UUID | None = None, policy_id: UUID | None = None, current: bool = False, + tax_benefit_model_name: str | None = None, + tax_benefit_model_version_id: UUID | None = None, skip: int = 0, limit: int = 100, session: Session = Depends(get_session), @@ -37,6 +40,12 @@ def list_parameter_values( policy_id: Filter by a specific policy reform. current: If true, only return values that are currently in effect (start_date <= now and (end_date is null or end_date > now)). + tax_benefit_model_name: Filter by country model name. + Use "policyengine-uk" for UK parameter values. + Use "policyengine-us" for US parameter values. + When specified without version_id, returns values from the latest version. + tax_benefit_model_version_id: Filter by specific model version UUID. + Takes precedence over tax_benefit_model_name if both are provided. """ query = select(ParameterValue) @@ -46,6 +55,17 @@ def list_parameter_values( if policy_id: query = query.where(ParameterValue.policy_id == policy_id) + # Resolve version ID from either explicit ID or model name (defaults to latest) + version_id = resolve_model_version_id( + tax_benefit_model_name, tax_benefit_model_version_id, session + ) + + if version_id: + # Join through Parameter to filter by model version + query = query.join(Parameter).where( + Parameter.tax_benefit_model_version_id == version_id + ) + if current: now = datetime.now(timezone.utc) query = query.where( diff --git a/src/policyengine_api/api/parameters.py b/src/policyengine_api/api/parameters.py index db029e5..b78928e 100644 --- a/src/policyengine_api/api/parameters.py +++ b/src/policyengine_api/api/parameters.py @@ -14,10 +14,9 @@ from policyengine_api.models import ( Parameter, ParameterRead, - TaxBenefitModel, - TaxBenefitModelVersion, ) from policyengine_api.services.database import get_session +from policyengine_api.services.tax_benefit_models import resolve_model_version_id router = APIRouter(prefix="/parameters", tags=["parameters"]) @@ -28,6 +27,7 @@ def list_parameters( limit: int = 100, search: str | None = None, tax_benefit_model_name: str | None = None, + tax_benefit_model_version_id: UUID | None = None, session: Session = Depends(get_session), ): """List available parameters with pagination and search. @@ -37,19 +37,22 @@ def list_parameters( Args: search: Filter by parameter name, label, or description. - tax_benefit_model_name: Filter by country model. + tax_benefit_model_name: Filter by country model name. Use "policyengine-uk" for UK parameters. Use "policyengine-us" for US parameters. + When specified without version_id, returns parameters from the latest version. + tax_benefit_model_version_id: Filter by specific model version UUID. + Takes precedence over tax_benefit_model_name if both are provided. """ query = select(Parameter) - # Filter by tax benefit model name (country) - if tax_benefit_model_name: - query = ( - query.join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == tax_benefit_model_name) - ) + # Resolve version ID from either explicit ID or model name (defaults to latest) + version_id = resolve_model_version_id( + tax_benefit_model_name, tax_benefit_model_version_id, session + ) + + if version_id: + query = query.where(Parameter.tax_benefit_model_version_id == version_id) if search: # Case-insensitive search using ILIKE diff --git a/src/policyengine_api/api/variables.py b/src/policyengine_api/api/variables.py index d660b1b..14b3764 100644 --- a/src/policyengine_api/api/variables.py +++ b/src/policyengine_api/api/variables.py @@ -12,12 +12,11 @@ from sqlmodel import Session, select from policyengine_api.models import ( - TaxBenefitModel, - TaxBenefitModelVersion, Variable, VariableRead, ) from policyengine_api.services.database import get_session +from policyengine_api.services.tax_benefit_models import resolve_model_version_id router = APIRouter(prefix="/variables", tags=["variables"]) @@ -28,6 +27,7 @@ def list_variables( limit: int = 100, search: str | None = None, tax_benefit_model_name: str | None = None, + tax_benefit_model_version_id: UUID | None = None, session: Session = Depends(get_session), ): """List available variables with pagination and search. @@ -38,19 +38,22 @@ def list_variables( Args: search: Filter by variable name, label, or description. - tax_benefit_model_name: Filter by country model. + tax_benefit_model_name: Filter by country model name. Use "policyengine-uk" for UK variables. Use "policyengine-us" for US variables. + When specified without version_id, returns variables from the latest version. + tax_benefit_model_version_id: Filter by specific model version UUID. + Takes precedence over tax_benefit_model_name if both are provided. """ query = select(Variable) - # Filter by tax benefit model name (country) - if tax_benefit_model_name: - query = ( - query.join(TaxBenefitModelVersion) - .join(TaxBenefitModel) - .where(TaxBenefitModel.name == tax_benefit_model_name) - ) + # Resolve version ID from either explicit ID or model name (defaults to latest) + version_id = resolve_model_version_id( + tax_benefit_model_name, tax_benefit_model_version_id, session + ) + + if version_id: + query = query.where(Variable.tax_benefit_model_version_id == version_id) if search: # Case-insensitive search using ILIKE diff --git a/src/policyengine_api/services/tax_benefit_models.py b/src/policyengine_api/services/tax_benefit_models.py new file mode 100644 index 0000000..72cfe40 --- /dev/null +++ b/src/policyengine_api/services/tax_benefit_models.py @@ -0,0 +1,107 @@ +"""Tax benefit model utilities. + +Shared utilities for working with tax benefit models and versions. +""" + +from uuid import UUID + +from fastapi import HTTPException +from sqlmodel import Session, select + +from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion + + +def get_latest_model_version( + tax_benefit_model_name: str, session: Session +) -> TaxBenefitModelVersion: + """Get the latest tax benefit model version for a given model name. + + Args: + tax_benefit_model_name: The model name (e.g., "policyengine-us" or "policyengine_us"). + session: Database session. + + Returns: + The latest TaxBenefitModelVersion for the model. + + Raises: + HTTPException: If model or version not found. + """ + # Normalize model name (allow underscores or hyphens) + model_name = tax_benefit_model_name.replace("_", "-") + + model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == model_name) + ).first() + if not model: + raise HTTPException( + status_code=404, detail=f"Tax benefit model {model_name} not found" + ) + + version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + if not version: + raise HTTPException( + status_code=404, detail=f"No version found for model {model_name}" + ) + + return version + + +def get_model_version_by_id( + version_id: UUID, session: Session +) -> TaxBenefitModelVersion: + """Get a specific tax benefit model version by ID. + + Args: + version_id: The UUID of the model version. + session: Database session. + + Returns: + The TaxBenefitModelVersion with the given ID. + + Raises: + HTTPException: If version not found. + """ + version = session.get(TaxBenefitModelVersion, version_id) + if not version: + raise HTTPException( + status_code=404, detail=f"Tax benefit model version {version_id} not found" + ) + return version + + +def resolve_model_version_id( + tax_benefit_model_name: str | None, + tax_benefit_model_version_id: UUID | None, + session: Session, +) -> UUID | None: + """Resolve the model version ID from either explicit ID or model name. + + If version_id is provided, validates and returns it. + If only model_name is provided, returns the latest version ID for that model. + If neither is provided, returns None. + + Args: + tax_benefit_model_name: Optional model name to get latest version for. + tax_benefit_model_version_id: Optional explicit version ID. + session: Database session. + + Returns: + The resolved version ID, or None if no filtering requested. + + Raises: + HTTPException: If specified version/model not found. + """ + if tax_benefit_model_version_id: + # Validate the version exists + version = get_model_version_by_id(tax_benefit_model_version_id, session) + return version.id + elif tax_benefit_model_name: + # Get the latest version for this model + version = get_latest_model_version(tax_benefit_model_name, session) + return version.id + else: + return None diff --git a/test_fixtures/fixtures_parameters.py b/test_fixtures/fixtures_parameters.py index ff69b0e..c369ea0 100644 --- a/test_fixtures/fixtures_parameters.py +++ b/test_fixtures/fixtures_parameters.py @@ -1,6 +1,6 @@ """Fixtures and helpers for parameter-related tests.""" -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone import pytest @@ -10,6 +10,7 @@ Policy, TaxBenefitModel, TaxBenefitModelVersion, + Variable, ) # ----------------------------------------------------------------------------- @@ -103,3 +104,98 @@ def create_parameter_values_batch( for pv in pvs: session.refresh(pv) return pvs + + +def create_model_and_version( + session, + model_name: str = "test-model", + version_string: str = "1.0.0", + created_at_offset_days: int = 0, +) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create a TaxBenefitModel and TaxBenefitModelVersion. + + Args: + session: Database session. + model_name: Name for the model. + version_string: Version string (e.g., "1.0.0"). + created_at_offset_days: Days to offset created_at (negative for past). + + Returns: + Tuple of (model, version). + """ + # Check if model already exists + from sqlmodel import select + + existing_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == model_name) + ).first() + + if existing_model: + model = existing_model + else: + model = TaxBenefitModel(name=model_name, description=f"Test model {model_name}") + session.add(model) + session.commit() + session.refresh(model) + + created_at = datetime.now(timezone.utc) + timedelta(days=created_at_offset_days) + version = TaxBenefitModelVersion( + model_id=model.id, + version=version_string, + description=f"Version {version_string}", + created_at=created_at, + ) + session.add(version) + session.commit() + session.refresh(version) + return model, version + + +def create_variable( + session, model_version, name: str, entity: str = "person" +) -> Variable: + """Create and persist a Variable.""" + var = Variable( + name=name, + entity=entity, + description=f"Test variable {name}", + data_type="float", + tax_benefit_model_version_id=model_version.id, + ) + session.add(var) + session.commit() + session.refresh(var) + return var + + +# ----------------------------------------------------------------------------- +# Multi-Version Fixtures +# ----------------------------------------------------------------------------- + + +@pytest.fixture +def two_model_versions(session): + """Create a model with two versions (old and new) for testing version filtering. + + Returns a dict with: + - model: The TaxBenefitModel + - old_version: The older TaxBenefitModelVersion (created 10 days ago) + - new_version: The newer TaxBenefitModelVersion (created now) + """ + model, old_version = create_model_and_version( + session, + model_name="policyengine-us", + version_string="1.0.0", + created_at_offset_days=-10, + ) + _, new_version = create_model_and_version( + session, + model_name="policyengine-us", + version_string="2.0.0", + created_at_offset_days=0, + ) + return { + "model": model, + "old_version": old_version, + "new_version": new_version, + } diff --git a/tests/test_parameters.py b/tests/test_parameters.py index f95016b..e35bc21 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -9,7 +9,9 @@ create_parameter_value, create_parameter_values_batch, create_policy, + create_variable, model_version, # noqa: F401 - pytest fixture + two_model_versions, # noqa: F401 - pytest fixture ) # ----------------------------------------------------------------------------- @@ -42,6 +44,121 @@ def test__given_nonexistent_parameter_id__then_returns_404(client): assert response.status_code == 404 +# ----------------------------------------------------------------------------- +# Parameter Version Filtering Tests +# ----------------------------------------------------------------------------- + + +def test__given_model_name_filter__then_returns_only_latest_version_parameters( + client, + session, + two_model_versions, # noqa: F811 +): + """GET /parameters?tax_benefit_model_name=X returns only latest version parameters.""" + # Given + old_version = two_model_versions["old_version"] + new_version = two_model_versions["new_version"] + + old_param = create_parameter( + session, old_version, "gov.old.param", "Old Version Param" + ) + new_param = create_parameter( + session, new_version, "gov.new.param", "New Version Param" + ) + + # When + response = client.get("/parameters?tax_benefit_model_name=policyengine-us") + + # Then + assert response.status_code == 200 + data = response.json() + param_ids = [p["id"] for p in data] + assert str(new_param.id) in param_ids + assert str(old_param.id) not in param_ids + + +def test__given_version_id_filter__then_returns_only_that_version_parameters( + client, + session, + two_model_versions, # noqa: F811 +): + """GET /parameters?tax_benefit_model_version_id=X returns only that version's parameters.""" + # Given + old_version = two_model_versions["old_version"] + new_version = two_model_versions["new_version"] + + old_param = create_parameter( + session, old_version, "gov.version.old", "Old Version Param" + ) + new_param = create_parameter( + session, new_version, "gov.version.new", "New Version Param" + ) + + # When + response = client.get( + f"/parameters?tax_benefit_model_version_id={old_version.id}" + ) + + # Then + assert response.status_code == 200 + data = response.json() + param_ids = [p["id"] for p in data] + assert str(old_param.id) in param_ids + assert str(new_param.id) not in param_ids + + +def test__given_both_model_name_and_version_id__then_version_id_takes_precedence( + client, + session, + two_model_versions, # noqa: F811 +): + """GET /parameters with both filters uses version_id (takes precedence).""" + # Given + old_version = two_model_versions["old_version"] + new_version = two_model_versions["new_version"] + + old_param = create_parameter( + session, old_version, "gov.precedence.old", "Old Precedence Param" + ) + create_parameter(session, new_version, "gov.precedence.new", "New Precedence Param") + + # When - pass both filters, version_id should win + response = client.get( + f"/parameters?tax_benefit_model_name=policyengine-us" + f"&tax_benefit_model_version_id={old_version.id}" + ) + + # Then - should get old version params (version_id takes precedence) + assert response.status_code == 200 + data = response.json() + param_ids = [p["id"] for p in data] + assert str(old_param.id) in param_ids + + +def test__given_nonexistent_model_name__then_returns_404(client): + """GET /parameters with non-existent model name returns 404.""" + # Given + nonexistent_model = "nonexistent-model" + + # When + response = client.get(f"/parameters?tax_benefit_model_name={nonexistent_model}") + + # Then + assert response.status_code == 404 + + +def test__given_nonexistent_version_id__then_returns_404(client): + """GET /parameters with non-existent version ID returns 404.""" + # Given + fake_version_id = uuid4() + + # When + response = client.get(f"/parameters?tax_benefit_model_version_id={fake_version_id}") + + # Then + assert response.status_code == 404 + + # ----------------------------------------------------------------------------- # Parameter Value Endpoint Tests # ----------------------------------------------------------------------------- @@ -200,5 +317,72 @@ def test__given_skip_parameter__then_skips_specified_results( assert len(response.json()) == 2 # 5 total - 3 skipped = 2 remaining +# ----------------------------------------------------------------------------- +# Parameter Value Version Filtering Tests +# ----------------------------------------------------------------------------- + + +def test__given_model_name_filter_on_values__then_returns_only_latest_version_values( + client, + session, + two_model_versions, # noqa: F811 +): + """GET /parameter-values?tax_benefit_model_name=X returns only latest version values.""" + # Given + old_version = two_model_versions["old_version"] + new_version = two_model_versions["new_version"] + + old_param = create_parameter( + session, old_version, "gov.pv.old.param", "Old PV Param" + ) + new_param = create_parameter( + session, new_version, "gov.pv.new.param", "New PV Param" + ) + old_pv = create_parameter_value(session, old_param.id, 100) + new_pv = create_parameter_value(session, new_param.id, 200) + + # When + response = client.get("/parameter-values?tax_benefit_model_name=policyengine-us") + + # Then + assert response.status_code == 200 + data = response.json() + pv_ids = [p["id"] for p in data] + assert str(new_pv.id) in pv_ids + assert str(old_pv.id) not in pv_ids + + +def test__given_version_id_filter_on_values__then_returns_only_that_version_values( + client, + session, + two_model_versions, # noqa: F811 +): + """GET /parameter-values?tax_benefit_model_version_id=X returns only that version's values.""" + # Given + old_version = two_model_versions["old_version"] + new_version = two_model_versions["new_version"] + + old_param = create_parameter( + session, old_version, "gov.pv.version.old", "Old PV Version Param" + ) + new_param = create_parameter( + session, new_version, "gov.pv.version.new", "New PV Version Param" + ) + old_pv = create_parameter_value(session, old_param.id, 100) + new_pv = create_parameter_value(session, new_param.id, 200) + + # When + response = client.get( + f"/parameter-values?tax_benefit_model_version_id={old_version.id}" + ) + + # Then + assert response.status_code == 200 + data = response.json() + pv_ids = [p["id"] for p in data] + assert str(old_pv.id) in pv_ids + assert str(new_pv.id) not in pv_ids + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_variables.py b/tests/test_variables.py index 76cb47a..b8c7009 100644 --- a/tests/test_variables.py +++ b/tests/test_variables.py @@ -4,18 +4,144 @@ import pytest +from test_fixtures.fixtures_parameters import ( + create_variable, + two_model_versions, # noqa: F401 - pytest fixture +) -def test_list_variables(client): - """List variables returns a list.""" - response = client.get("/variables") + +# ----------------------------------------------------------------------------- +# Variable Endpoint Basic Tests +# ----------------------------------------------------------------------------- + + +def test__given_variables_endpoint_called__then_returns_list(client): + """GET /variables returns a list.""" + # Given + endpoint = "/variables" + + # When + response = client.get(endpoint) + + # Then assert response.status_code == 200 assert isinstance(response.json(), list) -def test_get_variable_not_found(client): - """Get non-existent variable returns 404.""" +def test__given_nonexistent_variable_id__then_returns_404(client): + """GET /variables/{id} returns 404 for non-existent variable.""" + # Given fake_id = uuid4() + + # When response = client.get(f"/variables/{fake_id}") + + # Then + assert response.status_code == 404 + + +# ----------------------------------------------------------------------------- +# Variable Version Filtering Tests +# ----------------------------------------------------------------------------- + + +def test__given_model_name_filter__then_returns_only_latest_version_variables( + client, + session, + two_model_versions, # noqa: F811 +): + """GET /variables?tax_benefit_model_name=X returns only latest version variables.""" + # Given + old_version = two_model_versions["old_version"] + new_version = two_model_versions["new_version"] + + old_var = create_variable(session, old_version, "old_variable") + new_var = create_variable(session, new_version, "new_variable") + + # When + response = client.get("/variables?tax_benefit_model_name=policyengine-us") + + # Then + assert response.status_code == 200 + data = response.json() + var_ids = [v["id"] for v in data] + assert str(new_var.id) in var_ids + assert str(old_var.id) not in var_ids + + +def test__given_version_id_filter__then_returns_only_that_version_variables( + client, + session, + two_model_versions, # noqa: F811 +): + """GET /variables?tax_benefit_model_version_id=X returns only that version's variables.""" + # Given + old_version = two_model_versions["old_version"] + new_version = two_model_versions["new_version"] + + old_var = create_variable(session, old_version, "version_old_var") + new_var = create_variable(session, new_version, "version_new_var") + + # When + response = client.get( + f"/variables?tax_benefit_model_version_id={old_version.id}" + ) + + # Then + assert response.status_code == 200 + data = response.json() + var_ids = [v["id"] for v in data] + assert str(old_var.id) in var_ids + assert str(new_var.id) not in var_ids + + +def test__given_both_model_name_and_version_id__then_version_id_takes_precedence( + client, + session, + two_model_versions, # noqa: F811 +): + """GET /variables with both filters uses version_id (takes precedence).""" + # Given + old_version = two_model_versions["old_version"] + new_version = two_model_versions["new_version"] + + old_var = create_variable(session, old_version, "precedence_old_var") + create_variable(session, new_version, "precedence_new_var") + + # When - pass both filters, version_id should win + response = client.get( + f"/variables?tax_benefit_model_name=policyengine-us" + f"&tax_benefit_model_version_id={old_version.id}" + ) + + # Then - should get old version vars (version_id takes precedence) + assert response.status_code == 200 + data = response.json() + var_ids = [v["id"] for v in data] + assert str(old_var.id) in var_ids + + +def test__given_nonexistent_model_name__then_returns_404(client): + """GET /variables with non-existent model name returns 404.""" + # Given + nonexistent_model = "nonexistent-model" + + # When + response = client.get(f"/variables?tax_benefit_model_name={nonexistent_model}") + + # Then + assert response.status_code == 404 + + +def test__given_nonexistent_version_id__then_returns_404(client): + """GET /variables with non-existent version ID returns 404.""" + # Given + fake_version_id = uuid4() + + # When + response = client.get(f"/variables?tax_benefit_model_version_id={fake_version_id}") + + # Then assert response.status_code == 404 From cc57095601af2d05d515b9098a52acb2a9b94535 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 13 Jan 2026 00:36:57 +0300 Subject: [PATCH 2/2] fix: update aggregate tests to use valid simulations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tests for aggregate and change-aggregate endpoints were failing because they passed random UUIDs for simulation IDs, but the endpoints now validate that simulations exist before creating aggregates. Changes: - Add create_dataset and create_simulation factory functions to fixtures - Add simulation_fixture for tests needing a complete simulation setup - Update test_outputs.py to create real simulations and mock Modal calls - Update test_change_aggregates.py similarly with two_simulations fixture - Add tests for 404 cases when simulations don't exist 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- test_fixtures/fixtures_parameters.py | 71 ++++++++++++ tests/test_change_aggregates.py | 157 ++++++++++++++++++++++++--- tests/test_outputs.py | 92 ++++++++++++++-- 3 files changed, 297 insertions(+), 23 deletions(-) diff --git a/test_fixtures/fixtures_parameters.py b/test_fixtures/fixtures_parameters.py index c369ea0..d8a1941 100644 --- a/test_fixtures/fixtures_parameters.py +++ b/test_fixtures/fixtures_parameters.py @@ -5,9 +5,12 @@ import pytest from policyengine_api.models import ( + Dataset, Parameter, ParameterValue, Policy, + Simulation, + SimulationStatus, TaxBenefitModel, TaxBenefitModelVersion, Variable, @@ -199,3 +202,71 @@ def two_model_versions(session): "old_version": old_version, "new_version": new_version, } + + +# ----------------------------------------------------------------------------- +# Dataset and Simulation Factory Functions +# ----------------------------------------------------------------------------- + + +def create_dataset( + session, + model: TaxBenefitModel, + name: str = "test-dataset", + year: int = 2024, +) -> Dataset: + """Create and persist a Dataset.""" + dataset = Dataset( + name=name, + description=f"Test dataset {name}", + filepath=f"test/{name}.h5", + year=year, + tax_benefit_model_id=model.id, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + return dataset + + +def create_simulation( + session, + dataset: Dataset, + model_version: TaxBenefitModelVersion, + status: SimulationStatus = SimulationStatus.COMPLETED, +) -> Simulation: + """Create and persist a Simulation.""" + simulation = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=model_version.id, + status=status, + ) + session.add(simulation) + session.commit() + session.refresh(simulation) + return simulation + + +@pytest.fixture +def simulation_fixture(session): + """Create a complete simulation setup for testing. + + Returns a dict with: + - model: The TaxBenefitModel + - version: The TaxBenefitModelVersion + - dataset: The Dataset + - simulation: The Simulation + """ + model, version = create_model_and_version( + session, + model_name="policyengine-uk", + version_string="1.0.0", + ) + dataset = create_dataset(session, model, name="test-frs-2024") + simulation = create_simulation(session, dataset, version) + return { + "model": model, + "version": version, + "dataset": dataset, + "simulation": simulation, + } diff --git a/tests/test_change_aggregates.py b/tests/test_change_aggregates.py index 0b8ef87..df061e4 100644 --- a/tests/test_change_aggregates.py +++ b/tests/test_change_aggregates.py @@ -1,68 +1,199 @@ """Tests for change aggregate endpoints.""" +from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from test_fixtures.fixtures_parameters import ( + create_dataset, + create_model_and_version, + create_simulation, +) + + +# ----------------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------------- + + +@pytest.fixture +def two_simulations(session): + """Create two simulations (baseline and reform) for change aggregate tests.""" + model, version = create_model_and_version( + session, + model_name="policyengine-uk", + version_string="1.0.0", + ) + dataset = create_dataset(session, model, name="test-frs-2024") + baseline = create_simulation(session, dataset, version) + reform = create_simulation(session, dataset, version) + return { + "baseline": baseline, + "reform": reform, + } + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + def test_list_change_aggregates_empty(client): """List change aggregates returns empty list initially.""" - response = client.get("/outputs/change-aggregates") + # Given + endpoint = "/outputs/change-aggregates" + + # When + response = client.get(endpoint) + + # Then assert response.status_code == 200 assert isinstance(response.json(), list) -def test_create_single_change_aggregate(client): - """Create a single change aggregate.""" +@patch("policyengine_api.api.change_aggregates.modal.Function.from_name") +def test__given_valid_simulations__when_creating_single_change_aggregate__then_returns_200( + mock_modal_function, + client, + session, + two_simulations, +): + """Create a single change aggregate with valid simulations.""" + # Given + mock_fn = MagicMock() + mock_modal_function.return_value = mock_fn + baseline = two_simulations["baseline"] + reform = two_simulations["reform"] + + # When response = client.post( "/outputs/change-aggregates", json=[ { - "baseline_simulation_id": str(uuid4()), - "reform_simulation_id": str(uuid4()), + "baseline_simulation_id": str(baseline.id), + "reform_simulation_id": str(reform.id), "variable": "net_income", "aggregate_type": "sum", } ], ) + + # Then assert response.status_code == 200 data = response.json() assert isinstance(data, list) assert len(data) == 1 assert data[0]["variable"] == "net_income" assert data[0]["aggregate_type"] == "sum" + mock_fn.spawn.assert_called_once() + +@patch("policyengine_api.api.change_aggregates.modal.Function.from_name") +def test__given_valid_simulations__when_creating_multiple_change_aggregates__then_returns_200( + mock_modal_function, + client, + session, + two_simulations, +): + """Create multiple change aggregates in one request with valid simulations.""" + # Given + mock_fn = MagicMock() + mock_modal_function.return_value = mock_fn + baseline = two_simulations["baseline"] + reform = two_simulations["reform"] -def test_create_multiple_change_aggregates(client): - """Create multiple change aggregates in one request.""" - baseline_id = str(uuid4()) - reform_id = str(uuid4()) + # When response = client.post( "/outputs/change-aggregates", json=[ { - "baseline_simulation_id": baseline_id, - "reform_simulation_id": reform_id, + "baseline_simulation_id": str(baseline.id), + "reform_simulation_id": str(reform.id), "variable": "income_tax", "aggregate_type": "sum", }, { - "baseline_simulation_id": baseline_id, - "reform_simulation_id": reform_id, + "baseline_simulation_id": str(baseline.id), + "reform_simulation_id": str(reform.id), "variable": "benefits", "aggregate_type": "mean", }, ], ) + + # Then assert response.status_code == 200 data = response.json() assert len(data) == 2 + assert mock_fn.spawn.call_count == 2 + + +def test__given_nonexistent_baseline_simulation__when_creating_change_aggregate__then_returns_404( + client, + session, + two_simulations, +): + """Create change aggregate with non-existent baseline simulation returns 404.""" + # Given + reform = two_simulations["reform"] + fake_baseline_id = uuid4() + + # When + response = client.post( + "/outputs/change-aggregates", + json=[ + { + "baseline_simulation_id": str(fake_baseline_id), + "reform_simulation_id": str(reform.id), + "variable": "net_income", + "aggregate_type": "sum", + } + ], + ) + + # Then + assert response.status_code == 404 + assert "baseline" in response.json()["detail"].lower() + + +def test__given_nonexistent_reform_simulation__when_creating_change_aggregate__then_returns_404( + client, + session, + two_simulations, +): + """Create change aggregate with non-existent reform simulation returns 404.""" + # Given + baseline = two_simulations["baseline"] + fake_reform_id = uuid4() + + # When + response = client.post( + "/outputs/change-aggregates", + json=[ + { + "baseline_simulation_id": str(baseline.id), + "reform_simulation_id": str(fake_reform_id), + "variable": "net_income", + "aggregate_type": "sum", + } + ], + ) + + # Then + assert response.status_code == 404 + assert "reform" in response.json()["detail"].lower() def test_get_change_aggregate_not_found(client): """Get non-existent change aggregate returns 404.""" + # Given fake_id = uuid4() + + # When response = client.get(f"/outputs/change-aggregates/{fake_id}") + + # Then assert response.status_code == 404 diff --git a/tests/test_outputs.py b/tests/test_outputs.py index f4e18d7..db4c4a8 100644 --- a/tests/test_outputs.py +++ b/tests/test_outputs.py @@ -1,71 +1,143 @@ """Tests for aggregate outputs endpoints.""" +from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from test_fixtures.fixtures_parameters import ( + create_dataset, + create_model_and_version, + create_simulation, + simulation_fixture, # noqa: F401 - pytest fixture +) + def test_list_aggregates_empty(client): """List aggregates returns empty list initially.""" - response = client.get("/outputs/aggregates") + # Given + endpoint = "/outputs/aggregates" + + # When + response = client.get(endpoint) + + # Then assert response.status_code == 200 assert isinstance(response.json(), list) -def test_create_single_aggregate(client): - """Create a single aggregate output.""" +@patch("policyengine_api.api.outputs.modal.Function.from_name") +def test__given_valid_simulation__when_creating_single_aggregate__then_returns_200( + mock_modal_function, + client, + session, + simulation_fixture, # noqa: F811 +): + """Create a single aggregate output with valid simulation.""" + # Given + mock_fn = MagicMock() + mock_modal_function.return_value = mock_fn + simulation = simulation_fixture["simulation"] + + # When response = client.post( "/outputs/aggregates", json=[ { - "simulation_id": str(uuid4()), + "simulation_id": str(simulation.id), "variable": "net_income", "aggregate_type": "sum", } ], ) + + # Then assert response.status_code == 200 data = response.json() assert isinstance(data, list) assert len(data) == 1 assert data[0]["variable"] == "net_income" assert data[0]["aggregate_type"] == "sum" + mock_fn.spawn.assert_called_once() + +@patch("policyengine_api.api.outputs.modal.Function.from_name") +def test__given_valid_simulation__when_creating_multiple_aggregates__then_returns_200( + mock_modal_function, + client, + session, + simulation_fixture, # noqa: F811 +): + """Create multiple aggregate outputs in one request with valid simulation.""" + # Given + mock_fn = MagicMock() + mock_modal_function.return_value = mock_fn + simulation = simulation_fixture["simulation"] -def test_create_multiple_aggregates(client): - """Create multiple aggregate outputs in one request.""" - sim_id = str(uuid4()) + # When response = client.post( "/outputs/aggregates", json=[ { - "simulation_id": sim_id, + "simulation_id": str(simulation.id), "variable": "income_tax", "aggregate_type": "sum", }, { - "simulation_id": sim_id, + "simulation_id": str(simulation.id), "variable": "household_count", "aggregate_type": "count", }, { - "simulation_id": sim_id, + "simulation_id": str(simulation.id), "variable": "mean_income", "aggregate_type": "mean", }, ], ) + + # Then assert response.status_code == 200 data = response.json() assert len(data) == 3 variables = {d["variable"] for d in data} assert variables == {"income_tax", "household_count", "mean_income"} + assert mock_fn.spawn.call_count == 3 + + +def test__given_nonexistent_simulation__when_creating_aggregate__then_returns_404( + client, +): + """Create aggregate with non-existent simulation returns 404.""" + # Given + fake_simulation_id = uuid4() + + # When + response = client.post( + "/outputs/aggregates", + json=[ + { + "simulation_id": str(fake_simulation_id), + "variable": "net_income", + "aggregate_type": "sum", + } + ], + ) + + # Then + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() def test_get_aggregate_not_found(client): """Get non-existent aggregate returns 404.""" + # Given fake_id = uuid4() + + # When response = client.get(f"/outputs/aggregates/{fake_id}") + + # Then assert response.status_code == 404 assert response.json()["detail"] == "Aggregate not found"