From c21dc2e5c77beb953f19a826766e0955619db91c Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Tue, 13 Jan 2026 10:26:31 +0000 Subject: [PATCH 1/3] feat: support multiple households in calculation endpoints Changed entity inputs from single dicts to lists of dicts, enabling callers to submit multiple households with entity relational dataframes. The caller specifies {entity}_id fields in each entity and person_{entity}_id fields in people to link them together. For simple single-household requests, IDs can be omitted and all people will default to entity 0. Co-Authored-By: Claude --- src/policyengine_api/api/household.py | 668 +++++++++++++++++++------- src/policyengine_api/modal_app.py | 365 ++++++++++++-- tests/test_household.py | 118 ++++- 3 files changed, 920 insertions(+), 231 deletions(-) diff --git a/src/policyengine_api/api/household.py b/src/policyengine_api/api/household.py index 9ed1aef..0e89b5e 100644 --- a/src/policyengine_api/api/household.py +++ b/src/policyengine_api/api/household.py @@ -55,12 +55,34 @@ class HouseholdCalculateRequest(BaseModel): CORRECT: {"employment_income": 70000, "age": 40} WRONG: {"employment_income": {"2024": 70000}, "age": {"2024": 40}} - Example US request: + Supports multiple households via entity relational dataframes. Include + {entity}_id fields in each entity and person_{entity}_id fields in people + to link them together. + + Example US request (single household, simple): { "tax_benefit_model_name": "policyengine_us", "people": [{"employment_income": 70000, "age": 40}], - "tax_unit": {"state_code": "CA"}, - "household": {"state_fips": 6}, + "tax_unit": [{"state_code": "CA"}], + "household": [{"state_fips": 6}], + "year": 2024 + } + + Example US request (multiple households): + { + "tax_benefit_model_name": "policyengine_us", + "people": [ + {"person_id": 0, "person_household_id": 0, "person_tax_unit_id": 0, "age": 40, "employment_income": 70000}, + {"person_id": 1, "person_household_id": 1, "person_tax_unit_id": 1, "age": 30, "employment_income": 50000} + ], + "tax_unit": [ + {"tax_unit_id": 0, "state_code": "CA"}, + {"tax_unit_id": 1, "state_code": "NY"} + ], + "household": [ + {"household_id": 0, "state_fips": 6}, + {"household_id": 1, "state_fips": 36} + ], "year": 2024 } @@ -68,7 +90,7 @@ class HouseholdCalculateRequest(BaseModel): { "tax_benefit_model_name": "policyengine_uk", "people": [{"employment_income": 50000, "age": 30}], - "household": {}, + "household": [{}], "year": 2026 } """ @@ -77,27 +99,31 @@ class HouseholdCalculateRequest(BaseModel): description="Which country model to use" ) people: list[dict[str, Any]] = Field( - description="List of people with flat variable values (e.g. [{'age': 30, 'employment_income': 50000}]). Do NOT use time-period format." + description="List of people with flat variable values. Include person_id and person_{entity}_id fields to link to entities." ) - benunit: dict[str, Any] = Field( - default_factory=dict, description="UK benefit unit variables (flat values)" + benunit: list[dict[str, Any]] = Field( + default_factory=list, + description="UK benefit units. Include benunit_id to link with person_benunit_id in people.", ) - marital_unit: dict[str, Any] = Field( - default_factory=dict, description="US marital unit variables (flat values)" + marital_unit: list[dict[str, Any]] = Field( + default_factory=list, + description="US marital units. Include marital_unit_id to link with person_marital_unit_id in people.", ) - family: dict[str, Any] = Field( - default_factory=dict, description="US family variables (flat values)" + family: list[dict[str, Any]] = Field( + default_factory=list, + description="US families. Include family_id to link with person_family_id in people.", ) - spm_unit: dict[str, Any] = Field( - default_factory=dict, description="US SPM unit variables (flat values)" + spm_unit: list[dict[str, Any]] = Field( + default_factory=list, + description="US SPM units. Include spm_unit_id to link with person_spm_unit_id in people.", ) - tax_unit: dict[str, Any] = Field( - default_factory=dict, - description="US tax unit variables (flat values, e.g. {'state_code': 'CA'})", + tax_unit: list[dict[str, Any]] = Field( + default_factory=list, + description="US tax units. Include tax_unit_id to link with person_tax_unit_id in people.", ) - household: dict[str, Any] = Field( - default_factory=dict, - description="Household variables (flat values, e.g. {'state_fips': 6} for US)", + household: list[dict[str, Any]] = Field( + default_factory=list, + description="Households. Include household_id to link with person_household_id in people.", ) year: int | None = Field( default=None, @@ -120,7 +146,7 @@ class HouseholdCalculateResponse(BaseModel): family: list[dict[str, Any]] | None = None spm_unit: list[dict[str, Any]] | None = None tax_unit: list[dict[str, Any]] | None = None - household: dict[str, Any] + household: list[dict[str, Any]] class HouseholdJobResponse(BaseModel): @@ -143,13 +169,14 @@ class HouseholdImpactRequest(BaseModel): """Request body for household impact comparison. Same format as HouseholdCalculateRequest - use flat values, NOT time-period dictionaries. + Supports multiple households via entity relational dataframes. Example: { "tax_benefit_model_name": "policyengine_us", "people": [{"employment_income": 70000, "age": 40}], - "tax_unit": {"state_code": "CA"}, - "household": {"state_fips": 6}, + "tax_unit": [{"state_code": "CA"}], + "household": [{"state_fips": 6}], "year": 2024, "policy_id": "uuid-of-reform-policy" } @@ -159,25 +186,31 @@ class HouseholdImpactRequest(BaseModel): description="Which country model to use" ) people: list[dict[str, Any]] = Field( - description="List of people with flat variable values. Do NOT use time-period format." + description="List of people with flat variable values. Include person_id and person_{entity}_id fields to link to entities." ) - benunit: dict[str, Any] = Field( - default_factory=dict, description="UK benefit unit variables (flat values)" + benunit: list[dict[str, Any]] = Field( + default_factory=list, + description="UK benefit units. Include benunit_id to link with person_benunit_id in people.", ) - marital_unit: dict[str, Any] = Field( - default_factory=dict, description="US marital unit variables (flat values)" + marital_unit: list[dict[str, Any]] = Field( + default_factory=list, + description="US marital units. Include marital_unit_id to link with person_marital_unit_id in people.", ) - family: dict[str, Any] = Field( - default_factory=dict, description="US family variables (flat values)" + family: list[dict[str, Any]] = Field( + default_factory=list, + description="US families. Include family_id to link with person_family_id in people.", ) - spm_unit: dict[str, Any] = Field( - default_factory=dict, description="US SPM unit variables (flat values)" + spm_unit: list[dict[str, Any]] = Field( + default_factory=list, + description="US SPM units. Include spm_unit_id to link with person_spm_unit_id in people.", ) - tax_unit: dict[str, Any] = Field( - default_factory=dict, description="US tax unit variables (flat values)" + tax_unit: list[dict[str, Any]] = Field( + default_factory=list, + description="US tax units. Include tax_unit_id to link with person_tax_unit_id in people.", ) - household: dict[str, Any] = Field( - default_factory=dict, description="Household variables (flat values)" + household: list[dict[str, Any]] = Field( + default_factory=list, + description="Households. Include household_id to link with person_household_id in people.", ) year: int | None = Field( default=None, description="Simulation year (default: 2024 for US, 2026 for UK)" @@ -212,70 +245,26 @@ class HouseholdImpactJobStatusResponse(BaseModel): def _run_local_household_uk( job_id: str, people: list[dict], - benunit: dict, - household: dict, + benunit: list[dict], + household: list[dict], year: int, policy_data: dict | None, session: Session, ) -> None: - """Run UK household calculation locally.""" - from datetime import datetime, timezone + """Run UK household calculation locally. - from policyengine.tax_benefit_models.uk import uk_latest - from policyengine.tax_benefit_models.uk.analysis import ( - UKHouseholdInput, - calculate_household_impact, - ) + Supports multiple households via entity relational dataframes. + """ + from datetime import datetime, timezone try: - # Build policy if provided - policy = None - if policy_data: - from policyengine.core.policy import ParameterValue as PEParameterValue - from policyengine.core.policy import Policy as PEPolicy - - pe_param_values = [] - param_lookup = {p.name: p for p in uk_latest.parameters} - for pv in policy_data.get("parameter_values", []): - pe_param = param_lookup.get(pv["parameter_name"]) - if pe_param: - pe_pv = PEParameterValue( - parameter=pe_param, - value=pv["value"], - start_date=datetime.fromisoformat(pv["start_date"]) - if pv.get("start_date") - else None, - end_date=datetime.fromisoformat(pv["end_date"]) - if pv.get("end_date") - else None, - ) - pe_param_values.append(pe_pv) - policy = PEPolicy( - name=policy_data.get("name", ""), - description=policy_data.get("description", ""), - parameter_values=pe_param_values, - ) - - pe_input = UKHouseholdInput( - people=people, - benunit=benunit, - household=household, - year=year, - ) - - result = calculate_household_impact(pe_input, policy=policy) + result = _calculate_household_uk(people, benunit, household, year, policy_data) # Update job with result job = session.get(HouseholdJob, job_id) if job: job.status = HouseholdJobStatus.COMPLETED - job.result = _sanitize_for_json( - { - "person": result.person, - "benunit": result.benunit, - "household": result.household, - } - ) + job.result = _sanitize_for_json(result) job.completed_at = datetime.now(timezone.utc) session.add(job) session.commit() @@ -294,82 +283,197 @@ def _run_local_household_uk( raise +def _calculate_household_uk( + people: list[dict], + benunit: list[dict], + household: list[dict], + year: int, + policy_data: dict | None, +) -> dict: + """Calculate UK household(s) and return result dict. + + Supports multiple households via entity relational dataframes. If entity IDs + are not provided, defaults to single household with all people in it. + """ + import tempfile + from datetime import datetime + from pathlib import Path + + import pandas as pd + from policyengine.core import Simulation + from microdf import MicroDataFrame + from policyengine.tax_benefit_models.uk import uk_latest + from policyengine.tax_benefit_models.uk.datasets import PolicyEngineUKDataset + from policyengine.tax_benefit_models.uk.datasets import UKYearData + + n_people = len(people) + n_benunits = max(1, len(benunit)) + n_households = max(1, len(household)) + + # Build person data with defaults + person_data = { + "person_id": list(range(n_people)), + "person_benunit_id": [0] * n_people, + "person_household_id": [0] * n_people, + "person_weight": [1.0] * n_people, + } + # Add user-provided person fields + for i, person in enumerate(people): + for key, value in person.items(): + if key not in person_data: + person_data[key] = [0.0] * n_people + person_data[key][i] = value + + # Build benunit data with defaults + benunit_data = { + "benunit_id": list(range(n_benunits)), + "benunit_weight": [1.0] * n_benunits, + } + for i, bu in enumerate(benunit if benunit else [{}]): + for key, value in bu.items(): + if key not in benunit_data: + benunit_data[key] = [0.0] * n_benunits + benunit_data[key][i] = value + + # Build household data with defaults + household_data = { + "household_id": list(range(n_households)), + "household_weight": [1.0] * n_households, + "region": ["LONDON"] * n_households, + "tenure_type": ["RENT_PRIVATELY"] * n_households, + "council_tax": [0.0] * n_households, + "rent": [0.0] * n_households, + } + for i, hh in enumerate(household if household else [{}]): + for key, value in hh.items(): + if key not in household_data: + household_data[key] = [0.0] * n_households + household_data[key][i] = value + + # Create MicroDataFrames + person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") + benunit_df = MicroDataFrame(pd.DataFrame(benunit_data), weights="benunit_weight") + household_df = MicroDataFrame( + pd.DataFrame(household_data), weights="household_weight" + ) + + # Create temporary dataset + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / "household_calc.h5") + + dataset = PolicyEngineUKDataset( + name="Household calculation", + description="Household(s) for calculation", + filepath=filepath, + year=year, + data=UKYearData( + person=person_df, + benunit=benunit_df, + household=household_df, + ), + ) + + # Build policy if provided + policy = None + if policy_data: + from policyengine.core.policy import ParameterValue as PEParameterValue + from policyengine.core.policy import Policy as PEPolicy + + pe_param_values = [] + param_lookup = {p.name: p for p in uk_latest.parameters} + for pv in policy_data.get("parameter_values", []): + pe_param = param_lookup.get(pv["parameter_name"]) + if pe_param: + pe_pv = PEParameterValue( + parameter=pe_param, + value=pv["value"], + start_date=datetime.fromisoformat(pv["start_date"]) + if pv.get("start_date") + else None, + end_date=datetime.fromisoformat(pv["end_date"]) + if pv.get("end_date") + else None, + ) + pe_param_values.append(pe_pv) + policy = PEPolicy( + name=policy_data.get("name", ""), + description=policy_data.get("description", ""), + parameter_values=pe_param_values, + ) + + # Run simulation + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + policy=policy, + ) + simulation.run() + + # Extract outputs + output_data = simulation.output_dataset.data + + def safe_convert(value): + try: + return float(value) + except (ValueError, TypeError): + return str(value) + + person_outputs = [] + for i in range(n_people): + person_dict = {} + for var in uk_latest.entity_variables["person"]: + person_dict[var] = safe_convert(output_data.person[var].iloc[i]) + person_outputs.append(person_dict) + + benunit_outputs = [] + for i in range(len(output_data.benunit)): + benunit_dict = {} + for var in uk_latest.entity_variables["benunit"]: + benunit_dict[var] = safe_convert(output_data.benunit[var].iloc[i]) + benunit_outputs.append(benunit_dict) + + household_outputs = [] + for i in range(len(output_data.household)): + household_dict = {} + for var in uk_latest.entity_variables["household"]: + household_dict[var] = safe_convert(output_data.household[var].iloc[i]) + household_outputs.append(household_dict) + + return { + "person": person_outputs, + "benunit": benunit_outputs, + "household": household_outputs, + } + + def _run_local_household_us( job_id: str, people: list[dict], - marital_unit: dict, - family: dict, - spm_unit: dict, - tax_unit: dict, - household: dict, + marital_unit: list[dict], + family: list[dict], + spm_unit: list[dict], + tax_unit: list[dict], + household: list[dict], year: int, policy_data: dict | None, session: Session, ) -> None: - """Run US household calculation locally.""" - from datetime import datetime, timezone + """Run US household calculation locally. - from policyengine.tax_benefit_models.us import us_latest - from policyengine.tax_benefit_models.us.analysis import ( - USHouseholdInput, - calculate_household_impact, - ) + Supports multiple households via entity relational dataframes. + """ + from datetime import datetime, timezone try: - # Build policy if provided - policy = None - if policy_data: - from policyengine.core.policy import ParameterValue as PEParameterValue - from policyengine.core.policy import Policy as PEPolicy - - pe_param_values = [] - param_lookup = {p.name: p for p in us_latest.parameters} - for pv in policy_data.get("parameter_values", []): - pe_param = param_lookup.get(pv["parameter_name"]) - if pe_param: - pe_pv = PEParameterValue( - parameter=pe_param, - value=pv["value"], - start_date=datetime.fromisoformat(pv["start_date"]) - if pv.get("start_date") - else None, - end_date=datetime.fromisoformat(pv["end_date"]) - if pv.get("end_date") - else None, - ) - pe_param_values.append(pe_pv) - policy = PEPolicy( - name=policy_data.get("name", ""), - description=policy_data.get("description", ""), - parameter_values=pe_param_values, - ) - - pe_input = USHouseholdInput( - people=people, - marital_unit=marital_unit, - family=family, - spm_unit=spm_unit, - tax_unit=tax_unit, - household=household, - year=year, + result = _calculate_household_us( + people, marital_unit, family, spm_unit, tax_unit, household, year, policy_data ) - result = calculate_household_impact(pe_input, policy=policy) - # Update job with result job = session.get(HouseholdJob, job_id) if job: job.status = HouseholdJobStatus.COMPLETED - job.result = _sanitize_for_json( - { - "person": result.person, - "marital_unit": result.marital_unit, - "family": result.family, - "spm_unit": result.spm_unit, - "tax_unit": result.tax_unit, - "household": result.household, - } - ) + job.result = _sanitize_for_json(result) job.completed_at = datetime.now(timezone.utc) session.add(job) session.commit() @@ -388,6 +492,215 @@ def _run_local_household_us( raise +def _calculate_household_us( + people: list[dict], + marital_unit: list[dict], + family: list[dict], + spm_unit: list[dict], + tax_unit: list[dict], + household: list[dict], + year: int, + policy_data: dict | None, +) -> dict: + """Calculate US household(s) and return result dict. + + Supports multiple households via entity relational dataframes. If entity IDs + are not provided, defaults to single household with all people in it. + """ + import tempfile + from datetime import datetime + from pathlib import Path + + import pandas as pd + from policyengine.core import Simulation + from microdf import MicroDataFrame + from policyengine.tax_benefit_models.us import us_latest + from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset + from policyengine.tax_benefit_models.us.datasets import USYearData + + n_people = len(people) + n_households = max(1, len(household)) + n_marital_units = max(1, len(marital_unit)) + n_families = max(1, len(family)) + n_spm_units = max(1, len(spm_unit)) + n_tax_units = max(1, len(tax_unit)) + + # Build person data with defaults + person_data = { + "person_id": list(range(n_people)), + "person_household_id": [0] * n_people, + "person_marital_unit_id": [0] * n_people, + "person_family_id": [0] * n_people, + "person_spm_unit_id": [0] * n_people, + "person_tax_unit_id": [0] * n_people, + "person_weight": [1.0] * n_people, + } + for i, person in enumerate(people): + for key, value in person.items(): + if key not in person_data: + person_data[key] = [0.0] * n_people + person_data[key][i] = value + + # Build household data + household_data = { + "household_id": list(range(n_households)), + "household_weight": [1.0] * n_households, + } + for i, hh in enumerate(household if household else [{}]): + for key, value in hh.items(): + if key not in household_data: + household_data[key] = [0.0] * n_households + household_data[key][i] = value + + # Build marital_unit data + marital_unit_data = { + "marital_unit_id": list(range(n_marital_units)), + "marital_unit_weight": [1.0] * n_marital_units, + } + for i, mu in enumerate(marital_unit if marital_unit else [{}]): + for key, value in mu.items(): + if key not in marital_unit_data: + marital_unit_data[key] = [0.0] * n_marital_units + marital_unit_data[key][i] = value + + # Build family data + family_data = { + "family_id": list(range(n_families)), + "family_weight": [1.0] * n_families, + } + for i, fam in enumerate(family if family else [{}]): + for key, value in fam.items(): + if key not in family_data: + family_data[key] = [0.0] * n_families + family_data[key][i] = value + + # Build spm_unit data + spm_unit_data = { + "spm_unit_id": list(range(n_spm_units)), + "spm_unit_weight": [1.0] * n_spm_units, + } + for i, spm in enumerate(spm_unit if spm_unit else [{}]): + for key, value in spm.items(): + if key not in spm_unit_data: + spm_unit_data[key] = [0.0] * n_spm_units + spm_unit_data[key][i] = value + + # Build tax_unit data + tax_unit_data = { + "tax_unit_id": list(range(n_tax_units)), + "tax_unit_weight": [1.0] * n_tax_units, + } + for i, tu in enumerate(tax_unit if tax_unit else [{}]): + for key, value in tu.items(): + if key not in tax_unit_data: + tax_unit_data[key] = [0.0] * n_tax_units + tax_unit_data[key][i] = value + + # Create MicroDataFrames + person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") + household_df = MicroDataFrame( + pd.DataFrame(household_data), weights="household_weight" + ) + marital_unit_df = MicroDataFrame( + pd.DataFrame(marital_unit_data), weights="marital_unit_weight" + ) + family_df = MicroDataFrame(pd.DataFrame(family_data), weights="family_weight") + spm_unit_df = MicroDataFrame(pd.DataFrame(spm_unit_data), weights="spm_unit_weight") + tax_unit_df = MicroDataFrame(pd.DataFrame(tax_unit_data), weights="tax_unit_weight") + + # Create temporary dataset + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / "household_calc.h5") + + dataset = PolicyEngineUSDataset( + name="Household calculation", + description="Household(s) for calculation", + filepath=filepath, + year=year, + data=USYearData( + person=person_df, + household=household_df, + marital_unit=marital_unit_df, + family=family_df, + spm_unit=spm_unit_df, + tax_unit=tax_unit_df, + ), + ) + + # Build policy if provided + policy = None + if policy_data: + from policyengine.core.policy import ParameterValue as PEParameterValue + from policyengine.core.policy import Policy as PEPolicy + + pe_param_values = [] + param_lookup = {p.name: p for p in us_latest.parameters} + for pv in policy_data.get("parameter_values", []): + pe_param = param_lookup.get(pv["parameter_name"]) + if pe_param: + pe_pv = PEParameterValue( + parameter=pe_param, + value=pv["value"], + start_date=datetime.fromisoformat(pv["start_date"]) + if pv.get("start_date") + else None, + end_date=datetime.fromisoformat(pv["end_date"]) + if pv.get("end_date") + else None, + ) + pe_param_values.append(pe_pv) + policy = PEPolicy( + name=policy_data.get("name", ""), + description=policy_data.get("description", ""), + parameter_values=pe_param_values, + ) + + # Run simulation + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=us_latest, + policy=policy, + ) + simulation.run() + + # Extract outputs + output_data = simulation.output_dataset.data + + def safe_convert(value): + try: + return float(value) + except (ValueError, TypeError): + return str(value) + + def extract_entity_outputs(entity_name: str, entity_data, n_rows: int) -> list[dict]: + outputs = [] + for i in range(n_rows): + row_dict = {} + for var in us_latest.entity_variables[entity_name]: + row_dict[var] = safe_convert(entity_data[var].iloc[i]) + outputs.append(row_dict) + return outputs + + return { + "person": extract_entity_outputs("person", output_data.person, n_people), + "marital_unit": extract_entity_outputs( + "marital_unit", output_data.marital_unit, len(output_data.marital_unit) + ), + "family": extract_entity_outputs( + "family", output_data.family, len(output_data.family) + ), + "spm_unit": extract_entity_outputs( + "spm_unit", output_data.spm_unit, len(output_data.spm_unit) + ), + "tax_unit": extract_entity_outputs( + "tax_unit", output_data.tax_unit, len(output_data.tax_unit) + ), + "household": extract_entity_outputs( + "household", output_data.household, len(output_data.household) + ), + } + + def _trigger_modal_household( job_id: str, request: HouseholdCalculateRequest, @@ -596,7 +909,7 @@ def get_household_job_status( family=job.result.get("family"), spm_unit=job.result.get("spm_unit"), tax_unit=job.result.get("tax_unit"), - household=job.result.get("household", {}), + household=job.result.get("household", []), ) return HouseholdJobStatusResponse( @@ -613,40 +926,33 @@ def _compute_impact( """Compute difference between baseline and reform.""" impact = {} + def compute_entity_diff( + baseline_list: list[dict], reform_list: list[dict] + ) -> list[dict]: + """Compute differences for a list of entity dicts.""" + entity_impact = [] + for b_entity, r_entity in zip(baseline_list, reform_list): + entity_diff = {} + for key in b_entity: + if key in r_entity: + baseline_val = b_entity[key] + reform_val = r_entity[key] + if isinstance(baseline_val, (int, float)) and isinstance( + reform_val, (int, float) + ): + entity_diff[key] = { + "baseline": baseline_val, + "reform": reform_val, + "change": reform_val - baseline_val, + } + entity_impact.append(entity_diff) + return entity_impact + # Compute household-level differences - hh_impact = {} - for key in baseline.household: - if key in reform.household: - baseline_val = baseline.household[key] - reform_val = reform.household[key] - if isinstance(baseline_val, (int, float)) and isinstance( - reform_val, (int, float) - ): - hh_impact[key] = { - "baseline": baseline_val, - "reform": reform_val, - "change": reform_val - baseline_val, - } - impact["household"] = hh_impact + impact["household"] = compute_entity_diff(baseline.household, reform.household) # Compute person-level differences - person_impact = [] - for i, (bp, rp) in enumerate(zip(baseline.person, reform.person)): - person_diff = {} - for key in bp: - if key in rp: - baseline_val = bp[key] - reform_val = rp[key] - if isinstance(baseline_val, (int, float)) and isinstance( - reform_val, (int, float) - ): - person_diff[key] = { - "baseline": baseline_val, - "reform": reform_val, - "change": reform_val - baseline_val, - } - person_impact.append(person_diff) - impact["person"] = person_impact + impact["person"] = compute_entity_diff(baseline.person, reform.person) return impact @@ -837,7 +1143,7 @@ def get_household_impact_job_status( family=baseline_job.result.get("family"), spm_unit=baseline_job.result.get("spm_unit"), tax_unit=baseline_job.result.get("tax_unit"), - household=baseline_job.result.get("household", {}), + household=baseline_job.result.get("household", []), ) reform_result = HouseholdCalculateResponse( person=reform_job.result.get("person", []), @@ -846,7 +1152,7 @@ def get_household_impact_job_status( family=reform_job.result.get("family"), spm_unit=reform_job.result.get("spm_unit"), tax_unit=reform_job.result.get("tax_unit"), - household=reform_job.result.get("household", {}), + household=reform_job.result.get("household", []), ) impact = _compute_impact(baseline_result, reform_result) diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index 4769404..1aa8119 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -214,18 +214,24 @@ def download_dataset( def simulate_household_uk( job_id: str, people: list[dict], - benunit: dict, - household: dict, + benunit: list[dict], + household: list[dict], year: int, policy_data: dict | None, dynamic_data: dict | None, traceparent: str | None = None, ) -> None: - """Calculate UK household and write result to database.""" + """Calculate UK household(s) and write result to database. + + Supports multiple households via entity relational dataframes. + """ import json + import tempfile from datetime import datetime, timezone + from pathlib import Path import logfire + import pandas as pd from sqlmodel import Session, create_engine configure_logfire("policyengine-modal-uk", traceparent) @@ -236,10 +242,82 @@ def simulate_household_uk( engine = create_engine(database_url) try: + from policyengine.core import Simulation + from microdf import MicroDataFrame from policyengine.tax_benefit_models.uk import uk_latest - from policyengine.tax_benefit_models.uk.analysis import ( - UKHouseholdInput, - calculate_household_impact, + from policyengine.tax_benefit_models.uk.datasets import ( + PolicyEngineUKDataset, + ) + from policyengine.tax_benefit_models.uk.datasets import UKYearData + + n_people = len(people) + n_benunits = max(1, len(benunit)) + n_households = max(1, len(household)) + + # Build person data with defaults + person_data = { + "person_id": list(range(n_people)), + "person_benunit_id": [0] * n_people, + "person_household_id": [0] * n_people, + "person_weight": [1.0] * n_people, + } + for i, person in enumerate(people): + for key, value in person.items(): + if key not in person_data: + person_data[key] = [0.0] * n_people + person_data[key][i] = value + + # Build benunit data + benunit_data = { + "benunit_id": list(range(n_benunits)), + "benunit_weight": [1.0] * n_benunits, + } + for i, bu in enumerate(benunit if benunit else [{}]): + for key, value in bu.items(): + if key not in benunit_data: + benunit_data[key] = [0.0] * n_benunits + benunit_data[key][i] = value + + # Build household data + household_data = { + "household_id": list(range(n_households)), + "household_weight": [1.0] * n_households, + "region": ["LONDON"] * n_households, + "tenure_type": ["RENT_PRIVATELY"] * n_households, + "council_tax": [0.0] * n_households, + "rent": [0.0] * n_households, + } + for i, hh in enumerate(household if household else [{}]): + for key, value in hh.items(): + if key not in household_data: + household_data[key] = [0.0] * n_households + household_data[key][i] = value + + # Create MicroDataFrames + person_df = MicroDataFrame( + pd.DataFrame(person_data), weights="person_weight" + ) + benunit_df = MicroDataFrame( + pd.DataFrame(benunit_data), weights="benunit_weight" + ) + household_df = MicroDataFrame( + pd.DataFrame(household_data), weights="household_weight" + ) + + # Create temporary dataset + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / "household_calc.h5") + + dataset = PolicyEngineUKDataset( + name="Household calculation", + description="Household(s) for calculation", + filepath=filepath, + year=year, + data=UKYearData( + person=person_df, + benunit=benunit_df, + household=household_df, + ), ) # Build policy if provided @@ -274,15 +352,48 @@ def simulate_household_uk( parameter_values=pe_param_values, ) - pe_input = UKHouseholdInput( - people=people, - benunit=benunit, - household=household, - year=year, - ) - - with logfire.span("calculate_household_impact"): - result = calculate_household_impact(pe_input, policy=policy) + # Run simulation + with logfire.span("run_simulation"): + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + policy=policy, + ) + simulation.run() + + # Extract outputs + output_data = simulation.output_dataset.data + + def safe_convert(value): + try: + return float(value) + except (ValueError, TypeError): + return str(value) + + person_outputs = [] + for i in range(n_people): + person_dict = {} + for var in uk_latest.entity_variables["person"]: + person_dict[var] = safe_convert(output_data.person[var].iloc[i]) + person_outputs.append(person_dict) + + benunit_outputs = [] + for i in range(len(output_data.benunit)): + benunit_dict = {} + for var in uk_latest.entity_variables["benunit"]: + benunit_dict[var] = safe_convert( + output_data.benunit[var].iloc[i] + ) + benunit_outputs.append(benunit_dict) + + household_outputs = [] + for i in range(len(output_data.household)): + household_dict = {} + for var in uk_latest.entity_variables["household"]: + household_dict[var] = safe_convert( + output_data.household[var].iloc[i] + ) + household_outputs.append(household_dict) # Write result to database with Session(engine) as session: @@ -300,9 +411,9 @@ def simulate_household_uk( "job_id": job_id, "result": json.dumps( { - "person": result.person, - "benunit": result.benunit, - "household": result.household, + "person": person_outputs, + "benunit": benunit_outputs, + "household": household_outputs, } ), "completed_at": datetime.now(timezone.utc), @@ -345,21 +456,27 @@ def simulate_household_uk( def simulate_household_us( job_id: str, people: list[dict], - marital_unit: dict, - family: dict, - spm_unit: dict, - tax_unit: dict, - household: dict, + marital_unit: list[dict], + family: list[dict], + spm_unit: list[dict], + tax_unit: list[dict], + household: list[dict], year: int, policy_data: dict | None, dynamic_data: dict | None, traceparent: str | None = None, ) -> None: - """Calculate US household and write result to database.""" + """Calculate US household(s) and write result to database. + + Supports multiple households via entity relational dataframes. + """ import json + import tempfile from datetime import datetime, timezone + from pathlib import Path import logfire + import pandas as pd from sqlmodel import Session, create_engine configure_logfire("policyengine-modal-us", traceparent) @@ -370,10 +487,129 @@ def simulate_household_us( engine = create_engine(database_url) try: + from policyengine.core import Simulation + from microdf import MicroDataFrame from policyengine.tax_benefit_models.us import us_latest - from policyengine.tax_benefit_models.us.analysis import ( - USHouseholdInput, - calculate_household_impact, + from policyengine.tax_benefit_models.us.datasets import ( + PolicyEngineUSDataset, + ) + from policyengine.tax_benefit_models.us.datasets import USYearData + + n_people = len(people) + n_households = max(1, len(household)) + n_marital_units = max(1, len(marital_unit)) + n_families = max(1, len(family)) + n_spm_units = max(1, len(spm_unit)) + n_tax_units = max(1, len(tax_unit)) + + # Build person data with defaults + person_data = { + "person_id": list(range(n_people)), + "person_household_id": [0] * n_people, + "person_marital_unit_id": [0] * n_people, + "person_family_id": [0] * n_people, + "person_spm_unit_id": [0] * n_people, + "person_tax_unit_id": [0] * n_people, + "person_weight": [1.0] * n_people, + } + for i, person in enumerate(people): + for key, value in person.items(): + if key not in person_data: + person_data[key] = [0.0] * n_people + person_data[key][i] = value + + # Build household data + household_data = { + "household_id": list(range(n_households)), + "household_weight": [1.0] * n_households, + } + for i, hh in enumerate(household if household else [{}]): + for key, value in hh.items(): + if key not in household_data: + household_data[key] = [0.0] * n_households + household_data[key][i] = value + + # Build marital_unit data + marital_unit_data = { + "marital_unit_id": list(range(n_marital_units)), + "marital_unit_weight": [1.0] * n_marital_units, + } + for i, mu in enumerate(marital_unit if marital_unit else [{}]): + for key, value in mu.items(): + if key not in marital_unit_data: + marital_unit_data[key] = [0.0] * n_marital_units + marital_unit_data[key][i] = value + + # Build family data + family_data = { + "family_id": list(range(n_families)), + "family_weight": [1.0] * n_families, + } + for i, fam in enumerate(family if family else [{}]): + for key, value in fam.items(): + if key not in family_data: + family_data[key] = [0.0] * n_families + family_data[key][i] = value + + # Build spm_unit data + spm_unit_data = { + "spm_unit_id": list(range(n_spm_units)), + "spm_unit_weight": [1.0] * n_spm_units, + } + for i, spm in enumerate(spm_unit if spm_unit else [{}]): + for key, value in spm.items(): + if key not in spm_unit_data: + spm_unit_data[key] = [0.0] * n_spm_units + spm_unit_data[key][i] = value + + # Build tax_unit data + tax_unit_data = { + "tax_unit_id": list(range(n_tax_units)), + "tax_unit_weight": [1.0] * n_tax_units, + } + for i, tu in enumerate(tax_unit if tax_unit else [{}]): + for key, value in tu.items(): + if key not in tax_unit_data: + tax_unit_data[key] = [0.0] * n_tax_units + tax_unit_data[key][i] = value + + # Create MicroDataFrames + person_df = MicroDataFrame( + pd.DataFrame(person_data), weights="person_weight" + ) + household_df = MicroDataFrame( + pd.DataFrame(household_data), weights="household_weight" + ) + marital_unit_df = MicroDataFrame( + pd.DataFrame(marital_unit_data), weights="marital_unit_weight" + ) + family_df = MicroDataFrame( + pd.DataFrame(family_data), weights="family_weight" + ) + spm_unit_df = MicroDataFrame( + pd.DataFrame(spm_unit_data), weights="spm_unit_weight" + ) + tax_unit_df = MicroDataFrame( + pd.DataFrame(tax_unit_data), weights="tax_unit_weight" + ) + + # Create temporary dataset + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / "household_calc.h5") + + dataset = PolicyEngineUSDataset( + name="Household calculation", + description="Household(s) for calculation", + filepath=filepath, + year=year, + data=USYearData( + person=person_df, + household=household_df, + marital_unit=marital_unit_df, + family=family_df, + spm_unit=spm_unit_df, + tax_unit=tax_unit_df, + ), ) # Build policy if provided @@ -408,18 +644,34 @@ def simulate_household_us( parameter_values=pe_param_values, ) - pe_input = USHouseholdInput( - people=people, - marital_unit=marital_unit, - family=family, - spm_unit=spm_unit, - tax_unit=tax_unit, - household=household, - year=year, - ) - - with logfire.span("calculate_household_impact"): - result = calculate_household_impact(pe_input, policy=policy) + # Run simulation + with logfire.span("run_simulation"): + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=us_latest, + policy=policy, + ) + simulation.run() + + # Extract outputs + output_data = simulation.output_dataset.data + + def safe_convert(value): + try: + return float(value) + except (ValueError, TypeError): + return str(value) + + def extract_entity_outputs( + entity_name: str, entity_data, n_rows: int + ) -> list[dict]: + outputs = [] + for i in range(n_rows): + row_dict = {} + for var in us_latest.entity_variables[entity_name]: + row_dict[var] = safe_convert(entity_data[var].iloc[i]) + outputs.append(row_dict) + return outputs # Write result to database with Session(engine) as session: @@ -437,12 +689,34 @@ def simulate_household_us( "job_id": job_id, "result": json.dumps( { - "person": result.person, - "marital_unit": result.marital_unit, - "family": result.family, - "spm_unit": result.spm_unit, - "tax_unit": result.tax_unit, - "household": result.household, + "person": extract_entity_outputs( + "person", output_data.person, n_people + ), + "marital_unit": extract_entity_outputs( + "marital_unit", + output_data.marital_unit, + len(output_data.marital_unit), + ), + "family": extract_entity_outputs( + "family", + output_data.family, + len(output_data.family), + ), + "spm_unit": extract_entity_outputs( + "spm_unit", + output_data.spm_unit, + len(output_data.spm_unit), + ), + "tax_unit": extract_entity_outputs( + "tax_unit", + output_data.tax_unit, + len(output_data.tax_unit), + ), + "household": extract_entity_outputs( + "household", + output_data.household, + len(output_data.household), + ), } ), "completed_at": datetime.now(timezone.utc), @@ -572,7 +846,6 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N # Save output dataset with logfire.span("save_output_dataset"): - import tempfile from supabase import create_client output_filename = f"output_{simulation_id}.h5" diff --git a/tests/test_household.py b/tests/test_household.py index 3a79fe1..a7248b3 100644 --- a/tests/test_household.py +++ b/tests/test_household.py @@ -77,10 +77,12 @@ def test_with_household_data(self): json={ "tax_benefit_model_name": "policyengine_uk", "people": [{"age": 40, "employment_income": 45000}], - "household": { - "region": "LONDON", - "rent": 1500, - }, + "household": [ + { + "region": "LONDON", + "rent": 1500, + } + ], "year": 2026, }, ) @@ -154,6 +156,114 @@ def test_family_with_children(self): assert len(data["result"]["person"]) == 4 +class TestMultiHousehold: + """Tests for multiple household calculations.""" + + def test_multiple_uk_households(self): + """Test calculation for multiple UK households.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [ + # Person in household 0 + { + "person_id": 0, + "person_benunit_id": 0, + "person_household_id": 0, + "age": 30, + "employment_income": 30000, + }, + # Person in household 1 + { + "person_id": 1, + "person_benunit_id": 1, + "person_household_id": 1, + "age": 45, + "employment_income": 60000, + }, + ], + "benunit": [ + {"benunit_id": 0}, + {"benunit_id": 1}, + ], + "household": [ + {"household_id": 0, "region": "LONDON"}, + {"household_id": 1, "region": "NORTH_EAST"}, + ], + "year": 2026, + }, + ) + assert response.status_code == 200 + job_data = response.json() + data = _poll_job(job_data["job_id"]) + + assert len(data["result"]["person"]) == 2 + assert len(data["result"]["benunit"]) == 2 + assert len(data["result"]["household"]) == 2 + + def test_multiple_us_households(self): + """Test calculation for multiple US households.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_us", + "people": [ + # Person in household 0 + { + "person_id": 0, + "person_household_id": 0, + "person_tax_unit_id": 0, + "person_marital_unit_id": 0, + "person_family_id": 0, + "person_spm_unit_id": 0, + "age": 30, + "employment_income": 50000, + }, + # Person in household 1 + { + "person_id": 1, + "person_household_id": 1, + "person_tax_unit_id": 1, + "person_marital_unit_id": 1, + "person_family_id": 1, + "person_spm_unit_id": 1, + "age": 40, + "employment_income": 80000, + }, + ], + "household": [ + {"household_id": 0, "state_fips": 6}, # California + {"household_id": 1, "state_fips": 36}, # New York + ], + "tax_unit": [ + {"tax_unit_id": 0, "state_code": "CA"}, + {"tax_unit_id": 1, "state_code": "NY"}, + ], + "marital_unit": [ + {"marital_unit_id": 0}, + {"marital_unit_id": 1}, + ], + "family": [ + {"family_id": 0}, + {"family_id": 1}, + ], + "spm_unit": [ + {"spm_unit_id": 0}, + {"spm_unit_id": 1}, + ], + "year": 2024, + }, + ) + assert response.status_code == 200 + job_data = response.json() + data = _poll_job(job_data["job_id"]) + + assert len(data["result"]["person"]) == 2 + assert len(data["result"]["household"]) == 2 + assert len(data["result"]["tax_unit"]) == 2 + + class TestValidation: """Tests for request validation.""" From b599ab6f2272543ed2ae3538a0b1ae6882957dd0 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Tue, 13 Jan 2026 10:33:07 +0000 Subject: [PATCH 2/3] fix: route paths and test fixtures for aggregate endpoints - Changed route paths from "/" to "" to avoid trailing slash mismatch - Added simulation_id fixture to create test dependencies - Updated tests to use the fixture instead of random UUIDs Co-Authored-By: Claude --- src/policyengine_api/api/change_aggregates.py | 4 +- src/policyengine_api/api/outputs.py | 4 +- tests/conftest.py | 53 +++++++++++++++++++ tests/test_change_aggregates.py | 18 +++---- tests/test_outputs.py | 13 +++-- 5 files changed, 71 insertions(+), 21 deletions(-) diff --git a/src/policyengine_api/api/change_aggregates.py b/src/policyengine_api/api/change_aggregates.py index c9cb544..f706939 100644 --- a/src/policyengine_api/api/change_aggregates.py +++ b/src/policyengine_api/api/change_aggregates.py @@ -83,7 +83,7 @@ def _trigger_change_aggregate_computation( ) -@router.post("/", response_model=List[ChangeAggregateRead]) +@router.post("", response_model=List[ChangeAggregateRead]) def create_change_aggregates( outputs: List[ChangeAggregateCreate], background_tasks: BackgroundTasks, @@ -128,7 +128,7 @@ def create_change_aggregates( return db_outputs -@router.get("/", response_model=List[ChangeAggregateRead]) +@router.get("", response_model=List[ChangeAggregateRead]) def list_change_aggregates(session: Session = Depends(get_session)): """List all change aggregates.""" outputs = session.exec(select(ChangeAggregate)).all() diff --git a/src/policyengine_api/api/outputs.py b/src/policyengine_api/api/outputs.py index 521b9fa..f87cf62 100644 --- a/src/policyengine_api/api/outputs.py +++ b/src/policyengine_api/api/outputs.py @@ -82,7 +82,7 @@ def _trigger_aggregate_computation( ) -@router.post("/", response_model=List[AggregateOutputRead]) +@router.post("", response_model=List[AggregateOutputRead]) def create_aggregate_outputs( outputs: List[AggregateOutputCreate], background_tasks: BackgroundTasks, @@ -121,7 +121,7 @@ def create_aggregate_outputs( return db_outputs -@router.get("/", response_model=List[AggregateOutputRead]) +@router.get("", response_model=List[AggregateOutputRead]) def list_aggregate_outputs(session: Session = Depends(get_session)): """List all aggregates.""" outputs = session.exec(select(AggregateOutput)).all() diff --git a/tests/conftest.py b/tests/conftest.py index a14c5d7..8be9b3f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ """Pytest fixtures for tests.""" +from uuid import uuid4 + import pytest from fastapi.testclient import TestClient from fastapi_cache import FastAPICache @@ -8,6 +10,13 @@ from sqlmodel.pool import StaticPool from policyengine_api.main import app +from policyengine_api.models import ( + Dataset, + Simulation, + SimulationStatus, + TaxBenefitModel, + TaxBenefitModelVersion, +) from policyengine_api.services.database import get_session @@ -37,3 +46,47 @@ def get_session_override(): client = TestClient(app) yield client app.dependency_overrides.clear() + + +@pytest.fixture(name="simulation_id") +def simulation_fixture(session: Session): + """Create a test simulation with required dependencies.""" + # Create model + model = TaxBenefitModel(name="policyengine_uk", description="UK model") + session.add(model) + session.commit() + session.refresh(model) + + # Create model version + version = TaxBenefitModelVersion( + model_id=model.id, + version="test", + description="Test version", + ) + session.add(version) + session.commit() + session.refresh(version) + + # Create dataset + dataset = Dataset( + name="test_dataset", + description="Test dataset", + filepath="test/path/dataset.h5", + year=2024, + tax_benefit_model_id=model.id, + ) + session.add(dataset) + session.commit() + session.refresh(dataset) + + # Create simulation + simulation = Simulation( + dataset_id=dataset.id, + tax_benefit_model_version_id=version.id, + status=SimulationStatus.COMPLETED, + ) + session.add(simulation) + session.commit() + session.refresh(simulation) + + return str(simulation.id) diff --git a/tests/test_change_aggregates.py b/tests/test_change_aggregates.py index 0b8ef87..f9ef38f 100644 --- a/tests/test_change_aggregates.py +++ b/tests/test_change_aggregates.py @@ -12,14 +12,14 @@ def test_list_change_aggregates_empty(client): assert isinstance(response.json(), list) -def test_create_single_change_aggregate(client): +def test_create_single_change_aggregate(client, simulation_id): """Create a single change aggregate.""" response = client.post( "/outputs/change-aggregates", json=[ { - "baseline_simulation_id": str(uuid4()), - "reform_simulation_id": str(uuid4()), + "baseline_simulation_id": simulation_id, + "reform_simulation_id": simulation_id, "variable": "net_income", "aggregate_type": "sum", } @@ -33,22 +33,20 @@ def test_create_single_change_aggregate(client): assert data[0]["aggregate_type"] == "sum" -def test_create_multiple_change_aggregates(client): +def test_create_multiple_change_aggregates(client, simulation_id): """Create multiple change aggregates in one request.""" - baseline_id = str(uuid4()) - reform_id = str(uuid4()) response = client.post( "/outputs/change-aggregates", json=[ { - "baseline_simulation_id": baseline_id, - "reform_simulation_id": reform_id, + "baseline_simulation_id": simulation_id, + "reform_simulation_id": simulation_id, "variable": "income_tax", "aggregate_type": "sum", }, { - "baseline_simulation_id": baseline_id, - "reform_simulation_id": reform_id, + "baseline_simulation_id": simulation_id, + "reform_simulation_id": simulation_id, "variable": "benefits", "aggregate_type": "mean", }, diff --git a/tests/test_outputs.py b/tests/test_outputs.py index f4e18d7..2c9be32 100644 --- a/tests/test_outputs.py +++ b/tests/test_outputs.py @@ -12,13 +12,13 @@ def test_list_aggregates_empty(client): assert isinstance(response.json(), list) -def test_create_single_aggregate(client): +def test_create_single_aggregate(client, simulation_id): """Create a single aggregate output.""" response = client.post( "/outputs/aggregates", json=[ { - "simulation_id": str(uuid4()), + "simulation_id": simulation_id, "variable": "net_income", "aggregate_type": "sum", } @@ -32,24 +32,23 @@ def test_create_single_aggregate(client): assert data[0]["aggregate_type"] == "sum" -def test_create_multiple_aggregates(client): +def test_create_multiple_aggregates(client, simulation_id): """Create multiple aggregate outputs in one request.""" - sim_id = str(uuid4()) response = client.post( "/outputs/aggregates", json=[ { - "simulation_id": sim_id, + "simulation_id": simulation_id, "variable": "income_tax", "aggregate_type": "sum", }, { - "simulation_id": sim_id, + "simulation_id": simulation_id, "variable": "household_count", "aggregate_type": "count", }, { - "simulation_id": sim_id, + "simulation_id": simulation_id, "variable": "mean_income", "aggregate_type": "mean", }, From dd433333cdf03079d06b6ff325404b9e1c9ad0a1 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Tue, 13 Jan 2026 10:35:37 +0000 Subject: [PATCH 3/3] fix: mock Modal functions in aggregate tests --- tests/test_change_aggregates.py | 13 +++++++++++-- tests/test_outputs.py | 13 +++++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/tests/test_change_aggregates.py b/tests/test_change_aggregates.py index f9ef38f..8f9bb0e 100644 --- a/tests/test_change_aggregates.py +++ b/tests/test_change_aggregates.py @@ -1,5 +1,6 @@ """Tests for change aggregate endpoints.""" +from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest @@ -12,8 +13,12 @@ def test_list_change_aggregates_empty(client): assert isinstance(response.json(), list) -def test_create_single_change_aggregate(client, simulation_id): +@patch("policyengine_api.api.change_aggregates.modal.Function") +def test_create_single_change_aggregate(mock_modal_fn, client, simulation_id): """Create a single change aggregate.""" + mock_fn = MagicMock() + mock_modal_fn.from_name.return_value = mock_fn + response = client.post( "/outputs/change-aggregates", json=[ @@ -33,8 +38,12 @@ def test_create_single_change_aggregate(client, simulation_id): assert data[0]["aggregate_type"] == "sum" -def test_create_multiple_change_aggregates(client, simulation_id): +@patch("policyengine_api.api.change_aggregates.modal.Function") +def test_create_multiple_change_aggregates(mock_modal_fn, client, simulation_id): """Create multiple change aggregates in one request.""" + mock_fn = MagicMock() + mock_modal_fn.from_name.return_value = mock_fn + response = client.post( "/outputs/change-aggregates", json=[ diff --git a/tests/test_outputs.py b/tests/test_outputs.py index 2c9be32..cf0b145 100644 --- a/tests/test_outputs.py +++ b/tests/test_outputs.py @@ -1,5 +1,6 @@ """Tests for aggregate outputs endpoints.""" +from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest @@ -12,8 +13,12 @@ def test_list_aggregates_empty(client): assert isinstance(response.json(), list) -def test_create_single_aggregate(client, simulation_id): +@patch("policyengine_api.api.outputs.modal.Function") +def test_create_single_aggregate(mock_modal_fn, client, simulation_id): """Create a single aggregate output.""" + mock_fn = MagicMock() + mock_modal_fn.from_name.return_value = mock_fn + response = client.post( "/outputs/aggregates", json=[ @@ -32,8 +37,12 @@ def test_create_single_aggregate(client, simulation_id): assert data[0]["aggregate_type"] == "sum" -def test_create_multiple_aggregates(client, simulation_id): +@patch("policyengine_api.api.outputs.modal.Function") +def test_create_multiple_aggregates(mock_modal_fn, client, simulation_id): """Create multiple aggregate outputs in one request.""" + mock_fn = MagicMock() + mock_modal_fn.from_name.return_value = mock_fn + response = client.post( "/outputs/aggregates", json=[