Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 2 additions & 30 deletions src/policyengine_api/api/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@
ReportStatus,
Simulation,
SimulationStatus,
TaxBenefitModel,
TaxBenefitModelVersion,
)
from policyengine_api.services.database import get_session
from policyengine_api.services.tax_benefit_models import get_latest_model_version


def get_traceparent() -> str | None:
Expand Down Expand Up @@ -110,33 +109,6 @@ class EconomicImpactResponse(BaseModel):
program_statistics: list[ProgramStatisticsRead] | None = None


def _get_model_version(
tax_benefit_model_name: str, session: Session
) -> TaxBenefitModelVersion:
"""Get the latest tax benefit model version."""
model_name = tax_benefit_model_name.replace("_", "-")

model = session.exec(
select(TaxBenefitModel).where(TaxBenefitModel.name == model_name)
).first()
if not model:
raise HTTPException(
status_code=404, detail=f"Tax benefit model {model_name} not found"
)

version = session.exec(
select(TaxBenefitModelVersion)
.where(TaxBenefitModelVersion.model_id == model.id)
.order_by(TaxBenefitModelVersion.created_at.desc())
).first()
if not version:
raise HTTPException(
status_code=404, detail=f"No version found for model {model_name}"
)

return version


def _get_deterministic_simulation_id(
dataset_id: UUID,
model_version_id: UUID,
Expand Down Expand Up @@ -576,7 +548,7 @@ def economic_impact(
)

# Get model version
model_version = _get_model_version(request.tax_benefit_model_name, session)
model_version = get_latest_model_version(request.tax_benefit_model_name, session)

# Get or create simulations
baseline_sim = _get_or_create_simulation(
Expand Down
22 changes: 21 additions & 1 deletion src/policyengine_api/api/parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, or_, select

from policyengine_api.models import ParameterValue, ParameterValueRead
from policyengine_api.models import Parameter, ParameterValue, ParameterValueRead
from policyengine_api.services.database import get_session
from policyengine_api.services.tax_benefit_models import resolve_model_version_id

router = APIRouter(prefix="/parameter-values", tags=["parameter-values"])

Expand All @@ -23,6 +24,8 @@ def list_parameter_values(
parameter_id: UUID | None = None,
policy_id: UUID | None = None,
current: bool = False,
tax_benefit_model_name: str | None = None,
tax_benefit_model_version_id: UUID | None = None,
skip: int = 0,
limit: int = 100,
session: Session = Depends(get_session),
Expand All @@ -37,6 +40,12 @@ def list_parameter_values(
policy_id: Filter by a specific policy reform.
current: If true, only return values that are currently in effect
(start_date <= now and (end_date is null or end_date > now)).
tax_benefit_model_name: Filter by country model name.
Use "policyengine-uk" for UK parameter values.
Use "policyengine-us" for US parameter values.
When specified without version_id, returns values from the latest version.
tax_benefit_model_version_id: Filter by specific model version UUID.
Takes precedence over tax_benefit_model_name if both are provided.
"""
query = select(ParameterValue)

Expand All @@ -46,6 +55,17 @@ def list_parameter_values(
if policy_id:
query = query.where(ParameterValue.policy_id == policy_id)

# Resolve version ID from either explicit ID or model name (defaults to latest)
version_id = resolve_model_version_id(
tax_benefit_model_name, tax_benefit_model_version_id, session
)

if version_id:
# Join through Parameter to filter by model version
query = query.join(Parameter).where(
Parameter.tax_benefit_model_version_id == version_id
)

if current:
now = datetime.now(timezone.utc)
query = query.where(
Expand Down
23 changes: 13 additions & 10 deletions src/policyengine_api/api/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
from policyengine_api.models import (
Parameter,
ParameterRead,
TaxBenefitModel,
TaxBenefitModelVersion,
)
from policyengine_api.services.database import get_session
from policyengine_api.services.tax_benefit_models import resolve_model_version_id

router = APIRouter(prefix="/parameters", tags=["parameters"])

Expand All @@ -28,6 +27,7 @@ def list_parameters(
limit: int = 100,
search: str | None = None,
tax_benefit_model_name: str | None = None,
tax_benefit_model_version_id: UUID | None = None,
session: Session = Depends(get_session),
):
"""List available parameters with pagination and search.
Expand All @@ -37,19 +37,22 @@ def list_parameters(

Args:
search: Filter by parameter name, label, or description.
tax_benefit_model_name: Filter by country model.
tax_benefit_model_name: Filter by country model name.
Use "policyengine-uk" for UK parameters.
Use "policyengine-us" for US parameters.
When specified without version_id, returns parameters from the latest version.
tax_benefit_model_version_id: Filter by specific model version UUID.
Takes precedence over tax_benefit_model_name if both are provided.
"""
query = select(Parameter)

# Filter by tax benefit model name (country)
if tax_benefit_model_name:
query = (
query.join(TaxBenefitModelVersion)
.join(TaxBenefitModel)
.where(TaxBenefitModel.name == tax_benefit_model_name)
)
# Resolve version ID from either explicit ID or model name (defaults to latest)
version_id = resolve_model_version_id(
tax_benefit_model_name, tax_benefit_model_version_id, session
)

if version_id:
query = query.where(Parameter.tax_benefit_model_version_id == version_id)

if search:
# Case-insensitive search using ILIKE
Expand Down
23 changes: 13 additions & 10 deletions src/policyengine_api/api/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
from sqlmodel import Session, select

from policyengine_api.models import (
TaxBenefitModel,
TaxBenefitModelVersion,
Variable,
VariableRead,
)
from policyengine_api.services.database import get_session
from policyengine_api.services.tax_benefit_models import resolve_model_version_id

router = APIRouter(prefix="/variables", tags=["variables"])

Expand All @@ -28,6 +27,7 @@ def list_variables(
limit: int = 100,
search: str | None = None,
tax_benefit_model_name: str | None = None,
tax_benefit_model_version_id: UUID | None = None,
session: Session = Depends(get_session),
):
"""List available variables with pagination and search.
Expand All @@ -38,19 +38,22 @@ def list_variables(

Args:
search: Filter by variable name, label, or description.
tax_benefit_model_name: Filter by country model.
tax_benefit_model_name: Filter by country model name.
Use "policyengine-uk" for UK variables.
Use "policyengine-us" for US variables.
When specified without version_id, returns variables from the latest version.
tax_benefit_model_version_id: Filter by specific model version UUID.
Takes precedence over tax_benefit_model_name if both are provided.
"""
query = select(Variable)

# Filter by tax benefit model name (country)
if tax_benefit_model_name:
query = (
query.join(TaxBenefitModelVersion)
.join(TaxBenefitModel)
.where(TaxBenefitModel.name == tax_benefit_model_name)
)
# Resolve version ID from either explicit ID or model name (defaults to latest)
version_id = resolve_model_version_id(
tax_benefit_model_name, tax_benefit_model_version_id, session
)

if version_id:
query = query.where(Variable.tax_benefit_model_version_id == version_id)

if search:
# Case-insensitive search using ILIKE
Expand Down
107 changes: 107 additions & 0 deletions src/policyengine_api/services/tax_benefit_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Tax benefit model utilities.

Shared utilities for working with tax benefit models and versions.
"""

from uuid import UUID

from fastapi import HTTPException
from sqlmodel import Session, select

from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion


def get_latest_model_version(
tax_benefit_model_name: str, session: Session
) -> TaxBenefitModelVersion:
"""Get the latest tax benefit model version for a given model name.

Args:
tax_benefit_model_name: The model name (e.g., "policyengine-us" or "policyengine_us").
session: Database session.

Returns:
The latest TaxBenefitModelVersion for the model.

Raises:
HTTPException: If model or version not found.
"""
# Normalize model name (allow underscores or hyphens)
model_name = tax_benefit_model_name.replace("_", "-")

model = session.exec(
select(TaxBenefitModel).where(TaxBenefitModel.name == model_name)
).first()
if not model:
raise HTTPException(
status_code=404, detail=f"Tax benefit model {model_name} not found"
)

version = session.exec(
select(TaxBenefitModelVersion)
.where(TaxBenefitModelVersion.model_id == model.id)
.order_by(TaxBenefitModelVersion.created_at.desc())
).first()
if not version:
raise HTTPException(
status_code=404, detail=f"No version found for model {model_name}"
)

return version


def get_model_version_by_id(
version_id: UUID, session: Session
) -> TaxBenefitModelVersion:
"""Get a specific tax benefit model version by ID.

Args:
version_id: The UUID of the model version.
session: Database session.

Returns:
The TaxBenefitModelVersion with the given ID.

Raises:
HTTPException: If version not found.
"""
version = session.get(TaxBenefitModelVersion, version_id)
if not version:
raise HTTPException(
status_code=404, detail=f"Tax benefit model version {version_id} not found"
)
return version


def resolve_model_version_id(
tax_benefit_model_name: str | None,
tax_benefit_model_version_id: UUID | None,
session: Session,
) -> UUID | None:
"""Resolve the model version ID from either explicit ID or model name.

If version_id is provided, validates and returns it.
If only model_name is provided, returns the latest version ID for that model.
If neither is provided, returns None.

Args:
tax_benefit_model_name: Optional model name to get latest version for.
tax_benefit_model_version_id: Optional explicit version ID.
session: Database session.

Returns:
The resolved version ID, or None if no filtering requested.

Raises:
HTTPException: If specified version/model not found.
"""
if tax_benefit_model_version_id:
# Validate the version exists
version = get_model_version_by_id(tax_benefit_model_version_id, session)
return version.id
elif tax_benefit_model_name:
# Get the latest version for this model
version = get_latest_model_version(tax_benefit_model_name, session)
return version.id
else:
return None
Loading